1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <algorithm>
17#include <cmath>
18#include <cstddef>
19#include <cstdint>
20#include <cstring>
21#include <memory>
22#include <vector>
23
24#include "tensorflow/lite/c/builtin_op_data.h"
25#include "tensorflow/lite/c/common.h"
26#include "tensorflow/lite/kernels/cpu_backend_context.h"
27#include "tensorflow/lite/kernels/internal/compatibility.h"
28#include "tensorflow/lite/kernels/internal/kernel_utils.h"
29#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
30#include "tensorflow/lite/kernels/internal/quantization_util.h"
31#include "tensorflow/lite/kernels/internal/tensor.h"
32#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
33#include "tensorflow/lite/kernels/internal/tensor_utils.h"
34#include "tensorflow/lite/kernels/internal/types.h"
35#include "tensorflow/lite/kernels/kernel_util.h"
36#include "tensorflow/lite/kernels/lstm_eval.h"
37#include "tensorflow/lite/kernels/lstm_shared.h"
38
39namespace tflite {
40namespace ops {
41namespace builtin {
42namespace lstm {
43
44struct OpData {
45 // Which kernel type to use. Full kernel (24 inputs) or basic kernel (5
46 // inputs).
47 // Please note the 20-input full kernel is deprecated and only kept
48 // here for backward compatibility.
49 TfLiteLSTMKernelType kernel_type;
50
51 // If the lstm is layer norm.
52 bool use_layer_norm;
53
54 // These fields are only used by full kernel.
55 int scratch_tensor_index;
56 lstm_eval::IntegerLstmParameter integer_lstm_param;
57 bool compute_row_sums;
58
59 // Only used for sparse hybrid lstm kernels.
60 int ledger_index;
61 bool ledger_initialized;
62};
63
64namespace full {
65namespace {
66
67// Named temporary tensors.
68enum HybridTemporaryTensor {
69 kScratchBuffer = 0,
70 kInputQuantized = 1,
71 kOutputStateQuantized = 2,
72 kCellStateQuantized = 3,
73 kInputScalingFactors = 4,
74 kOutputStateScalingFactors = 5,
75 kProductScalingFactors = 6,
76 kRecoveredCellWeights = 7,
77 kAccumScratch = 8,
78 kInputZeroPoints = 9,
79 kOutputStateZeroPoints = 10,
80 kRowSums = 11,
81 kNumHybridTemporaryTensors = 12,
82};
83
84constexpr int kLedgersToAdd = 9;
85constexpr int kInputToInputWeightsLedgerOffset = 0;
86constexpr int kInputToForgetWeightsLedgerOffset = 1;
87constexpr int kInputToCellWeightsLedgerOffset = 2;
88constexpr int kInputToOutputWeightsLedgerOffset = 3;
89constexpr int kRecurrentToInputWeightsLedgerOffset = 4;
90constexpr int kRecurrentToForgetWeightsLedgerOffset = 5;
91constexpr int kRecurrentToCellWeightsLedgerOffset = 6;
92constexpr int kRecurrentToOutputWeightsLedgerOffset = 7;
93constexpr int kProjectionWeightsLedgerOffset = 8;
94
95TfLiteStatus make_ledger(const TfLiteSparsity* sparsity, TfLiteContext* context,
96 TfLiteTensor* ledger) {
97 ledger->type = kTfLiteUInt8;
98 ledger->name = "Lstm_ledger";
99 ledger->allocation_type = kTfLiteArenaRwPersistent;
100 if (sparsity == nullptr) {
101 return kTfLiteOk;
102 }
103 TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
104 ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
105 sparsity->dim_metadata[1].array_segments->size - 1;
106 return context->ResizeTensor(context, ledger, ledger_size);
107}
108
109TfLiteStatus copy_ledger(const TfLiteSparsity* sparsity, TfLiteTensor* ledger) {
110 if (sparsity == nullptr) {
111 return kTfLiteOk;
112 }
113
114 const auto* array_segments = sparsity->dim_metadata[1].array_segments;
115 const auto* array_indices = sparsity->dim_metadata[1].array_indices;
116 uint8_t* output_data = GetTensorData<uint8_t>(ledger);
117 int output_data_ptr = 0;
118
119 for (int i = 0; i < array_segments->size - 1; i++) {
120 int row_start = array_segments->data[i];
121 int row_end = array_segments->data[i + 1];
122 if (row_end - row_start > UINT8_MAX) {
123 return kTfLiteError;
124 }
125 // Copy num of non-zero blocks in row i.
126 output_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
127 output_data_ptr++;
128
129 for (int j = row_start; j < row_end; j++) {
130 if (array_indices->data[j] > UINT8_MAX) {
131 return kTfLiteError;
132 }
133 // Copy indices of non-zero blocks in row i.
134 output_data[output_data_ptr] =
135 static_cast<uint8_t>(array_indices->data[j]);
136 output_data_ptr++;
137 }
138 }
139 return kTfLiteOk;
140}
141
142TfLiteStatus PopulateQuantizedLstmParams8x8_16(
143 TfLiteContext* context, TfLiteNode* node,
144 lstm_eval::IntegerLstmParameter* integer_lstm_param) {
145 // Calculate quantized clip for projection and cell.
146 const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
147 const float cell_clip = params->cell_clip;
148 const float proj_clip = params->proj_clip;
149
150 const TfLiteTensor* cell_state =
151 GetVariableInput(context, node, kCellStateTensor);
152 TF_LITE_ENSURE(context, cell_state != nullptr);
153 TfLiteTensor* output_tensor;
154 TF_LITE_ENSURE_OK(
155 context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
156
157 auto* cell_state_params =
158 static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
159 auto* proj_params = static_cast<TfLiteAffineQuantization*>(
160 output_tensor->quantization.params);
161 if (cell_clip > 0.0) {
162 integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
163 std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
164 32767.0f));
165 } else {
166 integer_lstm_param->quantized_cell_clip = 0;
167 }
168 if (proj_clip > 0.0) {
169 integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
170 std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
171 } else {
172 integer_lstm_param->quantized_proj_clip = 0;
173 }
174
175 // Calculate effective scales.
176 OpData* op_data = static_cast<OpData*>(node->user_data);
177 const bool use_layer_norm = op_data->use_layer_norm;
178
179 const TfLiteTensor* input;
180 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
181
182 const TfLiteTensor* input_to_input_weights =
183 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
184 const TfLiteTensor* input_to_forget_weights;
185 TF_LITE_ENSURE_OK(context,
186 GetInputSafe(context, node, kInputToForgetWeightsTensor,
187 &input_to_forget_weights));
188 const TfLiteTensor* input_to_cell_weights;
189 TF_LITE_ENSURE_OK(context,
190 GetInputSafe(context, node, kInputToCellWeightsTensor,
191 &input_to_cell_weights));
192 const TfLiteTensor* input_to_output_weights;
193 TF_LITE_ENSURE_OK(context,
194 GetInputSafe(context, node, kInputToOutputWeightsTensor,
195 &input_to_output_weights));
196
197 const TfLiteTensor* recurrent_to_input_weights =
198 GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
199 const TfLiteTensor* recurrent_to_forget_weights;
200 TF_LITE_ENSURE_OK(context,
201 GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
202 &recurrent_to_forget_weights));
203 const TfLiteTensor* recurrent_to_cell_weights;
204 TF_LITE_ENSURE_OK(context,
205 GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
206 &recurrent_to_cell_weights));
207 const TfLiteTensor* recurrent_to_output_weights;
208 TF_LITE_ENSURE_OK(context,
209 GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
210 &recurrent_to_output_weights));
211
212 const TfLiteTensor* cell_to_input_weights =
213 GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
214 const TfLiteTensor* cell_to_forget_weights =
215 GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
216 const TfLiteTensor* cell_to_output_weights =
217 GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
218
219 const TfLiteTensor* input_layer_norm_coefficients =
220 GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor);
221 const TfLiteTensor* forget_layer_norm_coefficients =
222 GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor);
223 const TfLiteTensor* cell_layer_norm_coefficients =
224 GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
225 const TfLiteTensor* output_layer_norm_coefficients =
226 GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor);
227
228 const TfLiteTensor* projection_weights =
229 GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
230
231 TfLiteTensor* output_state =
232 GetVariableInput(context, node, kOutputStateTensor);
233 TF_LITE_ENSURE(context, output_state != nullptr);
234
235 // Since we have already checked that weights are all there or none, we can
236 // check the existence of only one to get the condition.
237 const bool use_cifg = (input_to_input_weights == nullptr);
238 const bool use_peephole = (cell_to_output_weights != nullptr);
239 const bool use_projection = (projection_weights != nullptr);
240
241 // Get intermediate scales and zero points.
242 std::vector<float> intermediate_scale;
243 std::vector<int32> intermediate_zp;
244 for (int i = 0; i < 4; ++i) {
245 if (use_layer_norm) {
246 TfLiteTensor* intermediate;
247 TF_LITE_ENSURE_OK(context,
248 GetIntermediatesSafe(context, node, i, &intermediate));
249 auto* params = static_cast<TfLiteAffineQuantization*>(
250 intermediate->quantization.params);
251 intermediate_scale.push_back(params->scale->data[0]);
252 intermediate_zp.push_back(params->zero_point->data[0]);
253 } else {
254 // Q3.12 for activation functions.
255 intermediate_scale.push_back(std::pow(2, -12));
256 intermediate_zp.push_back(0);
257 }
258 }
259 // In the absence of projection, hidden becomes output and this intermediate
260 // is ignored.
261 TfLiteTensor* hidden;
262 TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
263 auto* hidden_params =
264 static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
265 intermediate_scale.push_back(hidden_params->scale->data[0]);
266 intermediate_zp.push_back(hidden_params->zero_point->data[0]);
267
268 // Scales.
269 const float default_scale = 1.0;
270 float input_scale = default_scale;
271 float input_to_input_weight_scale = default_scale;
272 float recurrent_to_input_weight_scale = default_scale;
273 float cell_to_input_weight_scale = default_scale;
274 float input_to_forget_weight_scale = default_scale;
275 float recurrent_to_forget_weight_scale = default_scale;
276 float cell_to_forget_weight_scale = default_scale;
277 float input_to_cell_weight_scale = default_scale;
278 float recurrent_to_cell_weight_scale = default_scale;
279 float input_to_output_weight_scale = default_scale;
280 float recurrent_to_output_weight_scale = default_scale;
281 float cell_to_output_weight_scale = default_scale;
282 float projection_weight_scale = default_scale;
283 float layer_norm_input_scale = default_scale;
284 float layer_norm_forget_scale = default_scale;
285 float layer_norm_cell_scale = default_scale;
286 float layer_norm_output_scale = default_scale;
287 float output_state_scale = default_scale;
288 int cell_scale = 1;
289
290 // Effective scales.
291 float effective_input_to_input_scale = default_scale;
292 float effective_recurrent_to_input_scale = default_scale;
293 float effective_cell_to_input_scale = default_scale;
294 float effective_input_to_forget_scale = default_scale;
295 float effective_recurrent_to_forget_scale = default_scale;
296 float effective_cell_to_forget_scale = default_scale;
297 float effective_input_to_cell_scale = default_scale;
298 float effective_recurrent_to_cell_scale = default_scale;
299 float effective_input_to_output_scale = default_scale;
300 float effective_recurrent_to_output_scale = default_scale;
301 float effective_cell_to_output_scale = default_scale;
302 float effective_proj_scale = default_scale;
303 float effective_hidden_scale = default_scale;
304
305 // Populate scales.
306 if (!use_cifg) {
307 input_to_input_weight_scale = input_to_input_weights->params.scale;
308 recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
309 }
310
311 if (use_peephole) {
312 if (!use_cifg) {
313 cell_to_input_weight_scale = cell_to_input_weights->params.scale;
314 }
315 cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
316 cell_to_output_weight_scale = cell_to_output_weights->params.scale;
317 }
318
319 if (use_layer_norm) {
320 if (!use_cifg) {
321 layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
322 }
323 layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
324 layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
325 layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
326 }
327
328 if (use_projection) {
329 projection_weight_scale = projection_weights->params.scale;
330 }
331 output_state_scale = output_state->params.scale;
332
333 input_to_forget_weight_scale = input_to_forget_weights->params.scale;
334 input_to_cell_weight_scale = input_to_cell_weights->params.scale;
335 input_to_output_weight_scale = input_to_output_weights->params.scale;
336 recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
337 recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
338 recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
339
340 // Check cell state (already used above)
341 TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
342 TF_LITE_ENSURE(context, cell_scale <= -9);
343 integer_lstm_param->cell_scale = cell_scale;
344 input_scale = input->params.scale;
345
346 // Calculate effective scales.
347 if (!use_cifg) {
348 effective_input_to_input_scale =
349 input_to_input_weight_scale * input_scale / intermediate_scale[0];
350 effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
351 output_state_scale /
352 intermediate_scale[0];
353 }
354 effective_input_to_forget_scale =
355 input_to_forget_weight_scale * input_scale / intermediate_scale[1];
356 effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
357 output_state_scale /
358 intermediate_scale[1];
359
360 effective_input_to_cell_scale =
361 input_to_cell_weight_scale * input_scale / intermediate_scale[2];
362 effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
363 output_state_scale /
364 intermediate_scale[2];
365
366 effective_input_to_output_scale =
367 input_to_output_weight_scale * input_scale / intermediate_scale[3];
368 effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
369 output_state_scale /
370 intermediate_scale[3];
371
372 effective_hidden_scale =
373 std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15);
374
375 effective_proj_scale =
376 projection_weight_scale * intermediate_scale[4] / output_state_scale;
377
378 if (use_peephole) {
379 if (!use_cifg) {
380 effective_cell_to_input_scale = std::pow(2, cell_scale) * // NOLINT
381 cell_to_input_weight_scale /
382 intermediate_scale[0];
383 }
384 effective_cell_to_forget_scale = std::pow(2, cell_scale) * // NOLINT
385 cell_to_forget_weight_scale /
386 intermediate_scale[1];
387 effective_cell_to_output_scale = std::pow(2, cell_scale) * // NOLINT
388 cell_to_output_weight_scale /
389 intermediate_scale[3];
390 }
391
392 // Decompose scales.
393 QuantizeMultiplier(effective_input_to_input_scale,
394 &integer_lstm_param->effective_input_to_input_scale_a,
395 &integer_lstm_param->effective_input_to_input_scale_b);
396 QuantizeMultiplier(effective_recurrent_to_input_scale,
397 &integer_lstm_param->effective_recurrent_to_input_scale_a,
398 &integer_lstm_param->effective_recurrent_to_input_scale_b);
399 QuantizeMultiplier(effective_cell_to_input_scale,
400 &integer_lstm_param->effective_cell_to_input_scale_a,
401 &integer_lstm_param->effective_cell_to_input_scale_b);
402 QuantizeMultiplier(effective_input_to_forget_scale,
403 &integer_lstm_param->effective_input_to_forget_scale_a,
404 &integer_lstm_param->effective_input_to_forget_scale_b);
405 QuantizeMultiplier(
406 effective_recurrent_to_forget_scale,
407 &integer_lstm_param->effective_recurrent_to_forget_scale_a,
408 &integer_lstm_param->effective_recurrent_to_forget_scale_b);
409 QuantizeMultiplier(effective_cell_to_forget_scale,
410 &integer_lstm_param->effective_cell_to_forget_scale_a,
411 &integer_lstm_param->effective_cell_to_forget_scale_b);
412 QuantizeMultiplier(effective_input_to_cell_scale,
413 &integer_lstm_param->effective_input_to_cell_scale_a,
414 &integer_lstm_param->effective_input_to_cell_scale_b);
415 QuantizeMultiplier(effective_recurrent_to_cell_scale,
416 &integer_lstm_param->effective_recurrent_to_cell_scale_a,
417 &integer_lstm_param->effective_recurrent_to_cell_scale_b);
418 QuantizeMultiplier(effective_input_to_output_scale,
419 &integer_lstm_param->effective_input_to_output_scale_a,
420 &integer_lstm_param->effective_input_to_output_scale_b);
421 QuantizeMultiplier(
422 effective_recurrent_to_output_scale,
423 &integer_lstm_param->effective_recurrent_to_output_scale_a,
424 &integer_lstm_param->effective_recurrent_to_output_scale_b);
425 QuantizeMultiplier(effective_cell_to_output_scale,
426 &integer_lstm_param->effective_cell_to_output_scale_a,
427 &integer_lstm_param->effective_cell_to_output_scale_b);
428 QuantizeMultiplier(effective_proj_scale,
429 &integer_lstm_param->effective_proj_scale_a,
430 &integer_lstm_param->effective_proj_scale_b);
431 QuantizeMultiplier(effective_hidden_scale,
432 &integer_lstm_param->effective_hidden_scale_a,
433 &integer_lstm_param->effective_hidden_scale_b);
434 QuantizeMultiplier(layer_norm_input_scale,
435 &integer_lstm_param->layer_norm_input_scale_a,
436 &integer_lstm_param->layer_norm_input_scale_b);
437 QuantizeMultiplier(layer_norm_forget_scale,
438 &integer_lstm_param->layer_norm_forget_scale_a,
439 &integer_lstm_param->layer_norm_forget_scale_b);
440 QuantizeMultiplier(layer_norm_cell_scale,
441 &integer_lstm_param->layer_norm_cell_scale_a,
442 &integer_lstm_param->layer_norm_cell_scale_b);
443 QuantizeMultiplier(layer_norm_output_scale,
444 &integer_lstm_param->layer_norm_output_scale_a,
445 &integer_lstm_param->layer_norm_output_scale_b);
446
447 integer_lstm_param->hidden_zp = intermediate_zp[4];
448
449 // 10000 is used to make sure the kernel logic does not overflow.
450 if (!use_cifg) {
451 integer_lstm_param->input_variance_guard =
452 std::max(1, static_cast<int32_t>(10000 * layer_norm_input_scale));
453 }
454 integer_lstm_param->forget_variance_guard =
455 std::max(1, static_cast<int32_t>(10000 * layer_norm_forget_scale));
456 integer_lstm_param->cell_variance_guard =
457 std::max(1, static_cast<int32_t>(10000 * layer_norm_cell_scale));
458 integer_lstm_param->output_variance_guard =
459 std::max(1, static_cast<int32_t>(10000 * layer_norm_output_scale));
460
461 return kTfLiteOk;
462}
463
464TfLiteStatus PopulateQuantizedLstmParams8x8_8(
465 TfLiteContext* context, TfLiteNode* node,
466 lstm_eval::IntegerLstmParameter* integer_lstm_param) {
467 // Get all tensors.
468 const TfLiteTensor* input;
469 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
470 const TfLiteTensor* input_to_input_weights =
471 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
472 const TfLiteTensor* input_to_forget_weights;
473 TF_LITE_ENSURE_OK(context,
474 GetInputSafe(context, node, kInputToForgetWeightsTensor,
475 &input_to_forget_weights));
476 const TfLiteTensor* input_to_cell_weights;
477 TF_LITE_ENSURE_OK(context,
478 GetInputSafe(context, node, kInputToCellWeightsTensor,
479 &input_to_cell_weights));
480 const TfLiteTensor* input_to_output_weights;
481 TF_LITE_ENSURE_OK(context,
482 GetInputSafe(context, node, kInputToOutputWeightsTensor,
483 &input_to_output_weights));
484
485 const TfLiteTensor* recurrent_to_input_weights =
486 GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
487 const TfLiteTensor* recurrent_to_forget_weights;
488 TF_LITE_ENSURE_OK(context,
489 GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
490 &recurrent_to_forget_weights));
491 const TfLiteTensor* recurrent_to_cell_weights;
492 TF_LITE_ENSURE_OK(context,
493 GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
494 &recurrent_to_cell_weights));
495 const TfLiteTensor* recurrent_to_output_weights;
496 TF_LITE_ENSURE_OK(context,
497 GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
498 &recurrent_to_output_weights));
499
500 const TfLiteTensor* cell_to_input_weights =
501 GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
502 const TfLiteTensor* cell_to_forget_weights =
503 GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
504 const TfLiteTensor* cell_to_output_weights =
505 GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
506
507 const TfLiteTensor* input_layer_norm_coefficients =
508 GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor);
509 const TfLiteTensor* forget_layer_norm_coefficients =
510 GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor);
511 const TfLiteTensor* cell_layer_norm_coefficients =
512 GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
513 const TfLiteTensor* output_layer_norm_coefficients =
514 GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor);
515
516 const TfLiteTensor* input_gate_bias =
517 GetOptionalInputTensor(context, node, kInputGateBiasTensor);
518 const TfLiteTensor* forget_gate_bias;
519 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
520 &forget_gate_bias));
521 const TfLiteTensor* cell_gate_bias;
522 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
523 &cell_gate_bias));
524 const TfLiteTensor* output_gate_bias;
525 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
526 &output_gate_bias));
527
528 const TfLiteTensor* projection_weights =
529 GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
530 const TfLiteTensor* projection_bias =
531 GetOptionalInputTensor(context, node, kProjectionBiasTensor);
532
533 TfLiteTensor* output_state =
534 GetVariableInput(context, node, kOutputStateTensor);
535 TF_LITE_ENSURE(context, output_state != nullptr);
536 TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
537 TF_LITE_ENSURE(context, cell_state != nullptr);
538
539 // Since we have already checked that weights are all there or none, we can
540 // check the existence of only one to get the condition.
541 const bool use_cifg = (input_to_input_weights == nullptr);
542 const bool use_peephole = (cell_to_output_weights != nullptr);
543 const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr);
544 const bool use_projection = (projection_weights != nullptr);
545
546 // Weights and states.
547 int8_t* input_to_input_weight_ptr = nullptr;
548 int8_t* recurrent_to_input_weight_ptr = nullptr;
549 int8_t* cell_to_input_weight_ptr = nullptr;
550 int8_t* input_to_forget_weight_ptr = nullptr;
551 int8_t* recurrent_to_forget_weight_ptr = nullptr;
552 int8_t* cell_to_forget_weight_ptr = nullptr;
553 int8_t* input_to_cell_weight_ptr = nullptr;
554 int8_t* recurrent_to_cell_weight_ptr = nullptr;
555 int8_t* input_to_output_weight_ptr = nullptr;
556 int8_t* recurrent_to_output_weight_ptr = nullptr;
557 int8_t* cell_to_output_weight_ptr = nullptr;
558 int8_t* projection_weight_ptr = nullptr;
559 int16_t* layer_norm_input_weight_ptr = nullptr;
560 int16_t* layer_norm_forget_weight_ptr = nullptr;
561 int16_t* layer_norm_cell_weight_ptr = nullptr;
562 int16_t* layer_norm_output_weight_ptr = nullptr;
563 int32_t* input_gate_bias_ptr = nullptr;
564 int32_t* forget_gate_bias_ptr = nullptr;
565 int32_t* cell_gate_bias_ptr = nullptr;
566 int32_t* output_gate_bias_ptr = nullptr;
567 int32_t* projection_bias_ptr = nullptr;
568 int16_t* cell_ptr = nullptr;
569 int8_t* output_state_ptr = nullptr;
570
571 // Scales.
572 const float default_scale = 1.0;
573 float input_scale = default_scale;
574 float input_to_input_weight_scale = default_scale;
575 float recurrent_to_input_weight_scale = default_scale;
576 float cell_to_input_weight_scale = default_scale;
577 float input_to_forget_weight_scale = default_scale;
578 float recurrent_to_forget_weight_scale = default_scale;
579 float cell_to_forget_weight_scale = default_scale;
580 float input_to_cell_weight_scale = default_scale;
581 float recurrent_to_cell_weight_scale = default_scale;
582 float input_to_output_weight_scale = default_scale;
583 float recurrent_to_output_weight_scale = default_scale;
584 float cell_to_output_weight_scale = default_scale;
585 float projection_weight_scale = default_scale;
586 float layer_norm_input_scale = default_scale;
587 float layer_norm_forget_scale = default_scale;
588 float layer_norm_cell_scale = default_scale;
589 float layer_norm_output_scale = default_scale;
590 float output_state_scale = default_scale;
591
592 // Effective scales.
593 float effective_input_to_input_scale = default_scale;
594 float effective_recurrent_to_input_scale = default_scale;
595 float effective_cell_to_input_scale = default_scale;
596 float effective_input_to_forget_scale = default_scale;
597 float effective_recurrent_to_forget_scale = default_scale;
598 float effective_cell_to_forget_scale = default_scale;
599 float effective_input_to_cell_scale = default_scale;
600 float effective_recurrent_to_cell_scale = default_scale;
601 float effective_input_to_output_scale = default_scale;
602 float effective_recurrent_to_output_scale = default_scale;
603 float effective_cell_to_output_scale = default_scale;
604 float effective_proj_scale = default_scale;
605
606 // Zero points
607 int input_zp = 0;
608 int output_state_zp = 0;
609
610 // Populate all the values.
611 if (!use_cifg) {
612 input_to_input_weight_ptr = input_to_input_weights->data.int8;
613 recurrent_to_input_weight_ptr = recurrent_to_input_weights->data.int8;
614 input_gate_bias_ptr = input_gate_bias->data.i32;
615 input_to_input_weight_scale = input_to_input_weights->params.scale;
616 recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
617 }
618
619 if (use_peephole) {
620 if (!use_cifg) {
621 cell_to_input_weight_ptr = cell_to_input_weights->data.int8;
622 cell_to_input_weight_scale = cell_to_input_weights->params.scale;
623 }
624 cell_to_forget_weight_ptr = cell_to_forget_weights->data.int8;
625 cell_to_output_weight_ptr = cell_to_output_weights->data.int8;
626 cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
627 cell_to_output_weight_scale = cell_to_output_weights->params.scale;
628 }
629
630 if (is_layer_norm_lstm) {
631 if (!use_cifg) {
632 layer_norm_input_weight_ptr = input_layer_norm_coefficients->data.i16;
633 layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
634 }
635 layer_norm_forget_weight_ptr = forget_layer_norm_coefficients->data.i16;
636 layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
637 layer_norm_cell_weight_ptr = cell_layer_norm_coefficients->data.i16;
638 layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
639 layer_norm_output_weight_ptr = output_layer_norm_coefficients->data.i16;
640 layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
641 }
642
643 if (use_projection) {
644 projection_weight_ptr = projection_weights->data.int8;
645 projection_weight_scale = projection_weights->params.scale;
646 if (projection_bias) {
647 projection_bias_ptr = projection_bias->data.i32;
648 }
649 }
650 output_state_scale = output_state->params.scale;
651
652 input_to_forget_weight_ptr = input_to_forget_weights->data.int8;
653 input_to_forget_weight_scale = input_to_forget_weights->params.scale;
654 input_to_cell_weight_ptr = input_to_cell_weights->data.int8;
655 input_to_cell_weight_scale = input_to_cell_weights->params.scale;
656 input_to_output_weight_ptr = input_to_output_weights->data.int8;
657 input_to_output_weight_scale = input_to_output_weights->params.scale;
658 recurrent_to_forget_weight_ptr = recurrent_to_forget_weights->data.int8;
659 recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
660 recurrent_to_cell_weight_ptr = recurrent_to_cell_weights->data.int8;
661 recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
662 recurrent_to_output_weight_ptr = recurrent_to_output_weights->data.int8;
663 recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;
664 forget_gate_bias_ptr = forget_gate_bias->data.i32;
665 cell_gate_bias_ptr = cell_gate_bias->data.i32;
666 output_gate_bias_ptr = output_gate_bias->data.i32;
667 output_state_ptr = output_state->data.int8;
668 cell_ptr = cell_state->data.i16;
669 input_scale = input->params.scale;
670 input_zp = input->params.zero_point;
671 output_state_zp = output_state->params.zero_point;
672
673 std::vector<float> intermediate_scale;
674 for (int i = 0; i < 12; ++i) {
675 TfLiteTensor* intermediate =
676 &context->tensors[node->intermediates->data[i]];
677 auto* params = reinterpret_cast<TfLiteAffineQuantization*>(
678 intermediate->quantization.params);
679 intermediate_scale.push_back(params->scale->data[0]);
680 integer_lstm_param->intermediate_zp[i] = params->zero_point->data[0];
681 }
682
683 // Calculate effective scales.
684 if (!use_cifg) {
685 effective_input_to_input_scale =
686 input_to_input_weight_scale * input_scale / intermediate_scale[1];
687 effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
688 output_state_scale /
689 intermediate_scale[2];
690 }
691 effective_input_to_forget_scale =
692 input_to_forget_weight_scale * input_scale / intermediate_scale[4];
693 effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
694 output_state_scale /
695 intermediate_scale[5];
696
697 effective_input_to_cell_scale =
698 input_to_cell_weight_scale * input_scale / intermediate_scale[7];
699 effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
700 output_state_scale /
701 intermediate_scale[8];
702
703 effective_input_to_output_scale =
704 input_to_output_weight_scale * input_scale / intermediate_scale[10];
705 effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
706 output_state_scale /
707 intermediate_scale[11];
708 effective_proj_scale =
709 projection_weight_scale * std::pow(2, -15) / output_state_scale;
710
711 if (use_peephole) {
712 if (!use_cifg) {
713 effective_cell_to_input_scale =
714 std::pow(2, -15) * cell_to_input_weight_scale / intermediate_scale[0];
715 }
716 effective_cell_to_forget_scale =
717 std::pow(2, -15) * cell_to_forget_weight_scale / intermediate_scale[3];
718 effective_cell_to_output_scale =
719 std::pow(2, -15) * cell_to_output_weight_scale / intermediate_scale[9];
720 }
721
722 // Calculate effecgive scales.
723 QuantizeMultiplier(effective_input_to_input_scale,
724 &integer_lstm_param->effective_input_to_input_scale_a,
725 &integer_lstm_param->effective_input_to_input_scale_b);
726 QuantizeMultiplier(effective_recurrent_to_input_scale,
727 &integer_lstm_param->effective_recurrent_to_input_scale_a,
728 &integer_lstm_param->effective_recurrent_to_input_scale_b);
729 QuantizeMultiplier(effective_cell_to_input_scale,
730 &integer_lstm_param->effective_cell_to_input_scale_a,
731 &integer_lstm_param->effective_cell_to_input_scale_b);
732 QuantizeMultiplier(effective_input_to_forget_scale,
733 &integer_lstm_param->effective_input_to_forget_scale_a,
734 &integer_lstm_param->effective_input_to_forget_scale_b);
735 QuantizeMultiplier(
736 effective_recurrent_to_forget_scale,
737 &integer_lstm_param->effective_recurrent_to_forget_scale_a,
738 &integer_lstm_param->effective_recurrent_to_forget_scale_b);
739 QuantizeMultiplier(effective_cell_to_forget_scale,
740 &integer_lstm_param->effective_cell_to_forget_scale_a,
741 &integer_lstm_param->effective_cell_to_forget_scale_b);
742 QuantizeMultiplier(effective_input_to_cell_scale,
743 &integer_lstm_param->effective_input_to_cell_scale_a,
744 &integer_lstm_param->effective_input_to_cell_scale_b);
745 QuantizeMultiplier(effective_recurrent_to_cell_scale,
746 &integer_lstm_param->effective_recurrent_to_cell_scale_a,
747 &integer_lstm_param->effective_recurrent_to_cell_scale_b);
748 QuantizeMultiplier(effective_input_to_output_scale,
749 &integer_lstm_param->effective_input_to_output_scale_a,
750 &integer_lstm_param->effective_input_to_output_scale_b);
751 QuantizeMultiplier(
752 effective_recurrent_to_output_scale,
753 &integer_lstm_param->effective_recurrent_to_output_scale_a,
754 &integer_lstm_param->effective_recurrent_to_output_scale_b);
755 QuantizeMultiplier(effective_cell_to_output_scale,
756 &integer_lstm_param->effective_cell_to_output_scale_a,
757 &integer_lstm_param->effective_cell_to_output_scale_b);
758 QuantizeMultiplier(effective_proj_scale,
759 &integer_lstm_param->effective_proj_scale_a,
760 &integer_lstm_param->effective_proj_scale_b);
761 QuantizeMultiplier(layer_norm_input_scale,
762 &integer_lstm_param->layer_norm_input_scale_a,
763 &integer_lstm_param->layer_norm_input_scale_b);
764 QuantizeMultiplier(layer_norm_forget_scale,
765 &integer_lstm_param->layer_norm_forget_scale_a,
766 &integer_lstm_param->layer_norm_forget_scale_b);
767 QuantizeMultiplier(layer_norm_cell_scale,
768 &integer_lstm_param->layer_norm_cell_scale_a,
769 &integer_lstm_param->layer_norm_cell_scale_b);
770 QuantizeMultiplier(layer_norm_output_scale,
771 &integer_lstm_param->layer_norm_output_scale_a,
772 &integer_lstm_param->layer_norm_output_scale_b);
773
774 {
775 // Intermdiates in flatbuffer holds Wx, Wh and Wx+Wh.
776 // effective Wx, Wh is in effective_input/recurrent_to_<...>_scale
777 // So use intermediate_scale to hold scale from Wx and Wh to Wx+Wh
778 // 0: [1] -> [0]
779 // 1: [2] -> [0]
780 // and use intermdiate_zp as is.
781 const float s_1_0 = intermediate_scale[1] / intermediate_scale[0];
782 const float s_2_0 = intermediate_scale[2] / intermediate_scale[0];
783 const float s_4_3 = intermediate_scale[4] / intermediate_scale[3];
784 const float s_5_3 = intermediate_scale[5] / intermediate_scale[3];
785 const float s_7_6 = intermediate_scale[7] / intermediate_scale[6];
786 const float s_8_6 = intermediate_scale[8] / intermediate_scale[6];
787 const float s_10_9 = intermediate_scale[10] / intermediate_scale[9];
788 const float s_11_9 = intermediate_scale[11] / intermediate_scale[9];
789 QuantizeMultiplier(s_1_0, &integer_lstm_param->intermediate_scale_a[0],
790 &integer_lstm_param->intermediate_scale_b[0]);
791 QuantizeMultiplier(s_2_0, &integer_lstm_param->intermediate_scale_a[1],
792 &integer_lstm_param->intermediate_scale_b[1]);
793 QuantizeMultiplier(s_4_3, &integer_lstm_param->intermediate_scale_a[2],
794 &integer_lstm_param->intermediate_scale_b[2]);
795 QuantizeMultiplier(s_5_3, &integer_lstm_param->intermediate_scale_a[3],
796 &integer_lstm_param->intermediate_scale_b[3]);
797 QuantizeMultiplier(s_7_6, &integer_lstm_param->intermediate_scale_a[4],
798 &integer_lstm_param->intermediate_scale_b[4]);
799 QuantizeMultiplier(s_8_6, &integer_lstm_param->intermediate_scale_a[5],
800 &integer_lstm_param->intermediate_scale_b[5]);
801 QuantizeMultiplier(s_10_9, &integer_lstm_param->intermediate_scale_a[6],
802 &integer_lstm_param->intermediate_scale_b[6]);
803 QuantizeMultiplier(s_11_9, &integer_lstm_param->intermediate_scale_a[7],
804 &integer_lstm_param->intermediate_scale_b[7]);
805 }
806
807 // Calculate quantized clip for projection and cell.
808 const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
809 const float cell_clip = params->cell_clip;
810 const float proj_clip = params->proj_clip;
811
812 TfLiteTensor* output_tensor;
813 TF_LITE_ENSURE_OK(
814 context, GetOutputSafe(context, node, kOutputTensor, &output_tensor));
815
816 auto* cell_state_params = reinterpret_cast<TfLiteAffineQuantization*>(
817 cell_state->quantization.params);
818 auto* proj_params = reinterpret_cast<TfLiteAffineQuantization*>(
819 output_tensor->quantization.params);
820 TF_LITE_ENSURE_EQ(context, cell_state_params->scale->data[0], 1.0 / 32768);
821 if (cell_clip > 0.0 && cell_clip < 1.0) {
822 integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
823 std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
824 32767.0f));
825 } else {
826 integer_lstm_param->quantized_cell_clip = 0;
827 }
828 if (proj_clip > 0.0) {
829 integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
830 std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
831 } else {
832 integer_lstm_param->quantized_proj_clip = 0;
833 }
834 return kTfLiteOk;
835}
836
837} // namespace
838
839void* Init(TfLiteContext* context, const char* buffer, size_t length) {
840 auto* op_data = new OpData();
841 op_data->kernel_type = kTfLiteLSTMFullKernel;
842 // TODO(b/159066113): maybe just add the minimum required temp tensors?
843 context->AddTensors(context, kNumHybridTemporaryTensors,
844 &op_data->scratch_tensor_index);
845 // Tensors used for the sparse hybrid kernel.
846 context->AddTensors(context, /*tensors_to_add=*/kLedgersToAdd,
847 &op_data->ledger_index);
848 return op_data;
849}
850
851// LINT.IfChange
852// Check that input tensor dimensions matches with each other.
853TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
854 TfLiteNode* node, int n_input,
855 int n_output, int n_cell,
856 bool use_layer_norm, bool is_integer) {
857 const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
858
859 // Making sure clipping parameters have valid values.
860 // == 0 means no clipping
861 // > 0 means clipping
862 TF_LITE_ENSURE(context, params->cell_clip >= 0);
863 TF_LITE_ENSURE(context, params->proj_clip >= 0);
864
865 const TfLiteTensor* input_to_forget_weights;
866 TF_LITE_ENSURE_OK(context,
867 GetInputSafe(context, node, kInputToForgetWeightsTensor,
868 &input_to_forget_weights));
869 TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
870 TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
871 TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
872 TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) ||
873 (input_to_forget_weights->type == kTfLiteUInt8) ||
874 (input_to_forget_weights->type == kTfLiteInt8));
875
876 const TfLiteTensor* input_to_input_weights =
877 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
878 const bool use_cifg = (input_to_input_weights == nullptr);
879 if (!use_cifg) {
880 TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
881 TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
882 TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
883 TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type,
884 input_to_forget_weights->type);
885 }
886
887 const TfLiteTensor* input_to_cell_weights;
888 TF_LITE_ENSURE_OK(context,
889 GetInputSafe(context, node, kInputToCellWeightsTensor,
890 &input_to_cell_weights));
891 TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
892 TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
893 TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
894 TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
895 input_to_forget_weights->type);
896
897 const TfLiteTensor* recurrent_to_input_weights =
898 GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
899 if (recurrent_to_input_weights != nullptr) {
900 TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
901 TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
902 n_cell);
903 TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
904 n_output);
905 TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type,
906 input_to_forget_weights->type);
907 }
908
909 const TfLiteTensor* recurrent_to_forget_weights;
910 TF_LITE_ENSURE_OK(context,
911 GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
912 &recurrent_to_forget_weights));
913 TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
914 TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
915 n_cell);
916 TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
917 n_output);
918 TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
919 input_to_forget_weights->type);
920
921 const TfLiteTensor* recurrent_to_cell_weights;
922 TF_LITE_ENSURE_OK(context,
923 GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
924 &recurrent_to_cell_weights));
925 TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
926 TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
927 TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
928 n_output);
929 TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type,
930 input_to_forget_weights->type);
931
932 // We make sure the input-gate's parameters are either both present (regular
933 // LSTM) or not at all (CIFG-LSTM).
934 const bool cifg_weights_all_or_none =
935 ((input_to_input_weights != nullptr) &&
936 (recurrent_to_input_weights != nullptr)) ||
937 ((input_to_input_weights == nullptr) &&
938 (recurrent_to_input_weights == nullptr));
939 TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
940
941 const TfLiteTensor* cell_to_input_weights =
942 GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
943 if (cell_to_input_weights) {
944 TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
945 TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
946 TF_LITE_ENSURE_TYPES_EQ(
947 context, cell_to_input_weights->type,
948 is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
949 }
950
951 const TfLiteTensor* cell_to_forget_weights =
952 GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
953 if (cell_to_forget_weights) {
954 TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
955 TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
956 TF_LITE_ENSURE_TYPES_EQ(
957 context, cell_to_forget_weights->type,
958 is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
959 }
960
961 const TfLiteTensor* cell_to_output_weights =
962 GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
963 if (cell_to_output_weights) {
964 TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
965 TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
966 TF_LITE_ENSURE_TYPES_EQ(
967 context, cell_to_output_weights->type,
968 is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
969 }
970
971 // Making sure the peephole weights are there all or none.
972 const bool peephole_weights_all_or_none =
973 ((cell_to_input_weights != nullptr || use_cifg) &&
974 (cell_to_forget_weights != nullptr) &&
975 (cell_to_output_weights != nullptr)) ||
976 ((cell_to_input_weights == nullptr) &&
977 (cell_to_forget_weights == nullptr) &&
978 (cell_to_output_weights == nullptr));
979 TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
980
981 // Make sure the input gate bias is present only when not a CIFG-LSTM.
982 const TfLiteTensor* input_gate_bias =
983 GetOptionalInputTensor(context, node, kInputGateBiasTensor);
984 if (use_cifg) {
985 TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
986 } else {
987 TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
988 TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
989 if (is_integer) {
990 TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
991 } else {
992 TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
993 }
994 }
995
996 const TfLiteTensor* forget_gate_bias;
997 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
998 &forget_gate_bias));
999 TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
1000 TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
1001 if (is_integer) {
1002 TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
1003 } else {
1004 TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
1005 }
1006
1007 const TfLiteTensor* cell_gate_bias;
1008 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
1009 &cell_gate_bias));
1010 TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
1011 TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
1012 if (is_integer) {
1013 TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
1014 } else {
1015 TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
1016 }
1017
1018 const TfLiteTensor* output_gate_bias;
1019 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
1020 &output_gate_bias));
1021 TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
1022 TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
1023 if (is_integer) {
1024 TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
1025 } else {
1026 TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
1027 }
1028
1029 const TfLiteTensor* projection_weights =
1030 GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
1031 if (projection_weights != nullptr) {
1032 TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
1033 TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
1034 TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
1035 TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type,
1036 input_to_forget_weights->type);
1037 }
1038
1039 const TfLiteTensor* projection_bias =
1040 GetOptionalInputTensor(context, node, kProjectionBiasTensor);
1041 if (projection_bias != nullptr) {
1042 TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
1043 TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
1044 if (is_integer) {
1045 TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
1046 } else {
1047 TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
1048 }
1049 }
1050
1051 // Making sure the projection tensors are consistent:
1052 // 1) If projection weight is not present, then projection bias should not be
1053 // present.
1054 // 2) If projection weight is present, then projection bias is optional.
1055 // TODO(ghodrat): make sure this is correct.
1056 const bool projection_tensors_consistent =
1057 ((projection_weights != nullptr) || (projection_bias == nullptr));
1058 TF_LITE_ENSURE(context, projection_tensors_consistent == true);
1059
1060 if (use_layer_norm) {
1061 const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
1062 context, node, kInputLayerNormCoefficientsTensor);
1063 if (use_cifg) {
1064 TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
1065 } else {
1066 TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
1067 TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
1068 TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
1069 n_cell);
1070 if (is_integer) {
1071 TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
1072 kTfLiteInt16);
1073 } else {
1074 TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
1075 kTfLiteFloat32);
1076 }
1077 }
1078
1079 const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
1080 context, node, kForgetLayerNormCoefficientsTensor);
1081 TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr);
1082 TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
1083 TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
1084 n_cell);
1085 if (is_integer) {
1086 TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
1087 kTfLiteInt16);
1088 } else {
1089 TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
1090 kTfLiteFloat32);
1091 }
1092
1093 const TfLiteTensor* cell_layer_norm_coefficients =
1094 GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
1095 TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr);
1096 TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
1097 TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
1098 n_cell);
1099 if (is_integer) {
1100 TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
1101 kTfLiteInt16);
1102 } else {
1103 TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
1104 kTfLiteFloat32);
1105 }
1106
1107 const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
1108 context, node, kOutputLayerNormCoefficientsTensor);
1109 TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr);
1110 TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
1111 TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
1112 n_cell);
1113 if (is_integer) {
1114 TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
1115 kTfLiteInt16);
1116 } else {
1117 TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
1118 kTfLiteFloat32);
1119 }
1120 }
1121
1122 return kTfLiteOk;
1123}
1124// LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
1125
1126TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
1127 TfLiteContext* context, int32_t zero_point,
1128 const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
1129 std::unique_ptr<int32_t[]>* output) {
1130 if (weight_tensor == nullptr) {
1131 return kTfLiteOk;
1132 }
1133
1134 const RuntimeShape& weight_shape = GetTensorShape(weight_tensor);
1135 TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
1136 const int row = weight_shape.Dims(0);
1137 const int col = weight_shape.Dims(1);
1138 output->reset(new int32_t[row]);
1139 if (bias_tensor == nullptr) {
1140 memset(output->get(), 0, row * sizeof(int32_t));
1141 } else {
1142 const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
1143 memcpy(output->get(), bias, row * sizeof(int32_t));
1144 }
1145 if (zero_point != 0) {
1146 const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
1147 tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col,
1148 output->get());
1149 }
1150 return kTfLiteOk;
1151}
1152
1153TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
1154 OpData* op_data,
1155 TfLiteNode* node) {
1156 const TfLiteTensor* input;
1157 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
1158 const TfLiteTensor* output_state =
1159 GetVariableInput(context, node, kOutputStateTensor);
1160 TF_LITE_ENSURE(context, output_state != nullptr);
1161
1162 const int32_t input_zero_point = -input->params.zero_point;
1163 const int32_t output_state_zero_point = -output_state->params.zero_point;
1164
1165 const TfLiteTensor* input_to_input_weights =
1166 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1167 const TfLiteTensor* input_to_forget_weights;
1168 TF_LITE_ENSURE_OK(context,
1169 GetInputSafe(context, node, kInputToForgetWeightsTensor,
1170 &input_to_forget_weights));
1171 const TfLiteTensor* input_to_cell_weights;
1172 TF_LITE_ENSURE_OK(context,
1173 GetInputSafe(context, node, kInputToCellWeightsTensor,
1174 &input_to_cell_weights));
1175 const TfLiteTensor* input_to_output_weights;
1176 TF_LITE_ENSURE_OK(context,
1177 GetInputSafe(context, node, kInputToOutputWeightsTensor,
1178 &input_to_output_weights));
1179
1180 const TfLiteTensor* recurrent_to_input_weights =
1181 GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
1182 const TfLiteTensor* recurrent_to_forget_weights;
1183 TF_LITE_ENSURE_OK(context,
1184 GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
1185 &recurrent_to_forget_weights));
1186 const TfLiteTensor* recurrent_to_cell_weights;
1187 TF_LITE_ENSURE_OK(context,
1188 GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
1189 &recurrent_to_cell_weights));
1190 const TfLiteTensor* recurrent_to_output_weights;
1191 TF_LITE_ENSURE_OK(context,
1192 GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
1193 &recurrent_to_output_weights));
1194
1195 const TfLiteTensor* projection_weights =
1196 GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
1197 const TfLiteTensor* projection_bias =
1198 GetOptionalInputTensor(context, node, kProjectionBiasTensor);
1199
1200 lstm_eval::IntegerLstmParameter* integer_lstm_params =
1201 &op_data->integer_lstm_param;
1202
1203 const TfLiteTensor* intermediate =
1204 &context->tensors[node->intermediates->data[4]];
1205 const auto* params =
1206 static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params);
1207 const int32_t hidden_zp = params->zero_point->data[0];
1208
1209 // Get bias and perform zero point calculation.
1210 // When there is layer normalization, the gate bias does not apply to matmul
1211 // directly:
1212 // y = ln(w * x + w * r + w * c) + b.
1213 const bool is_layer_norm = op_data->use_layer_norm;
1214
1215 // Forget gate.
1216 const TfLiteTensor* forget_gate_bias =
1217 is_layer_norm ? nullptr : GetInput(context, node, kForgetGateBiasTensor);
1218 TF_LITE_ENSURE_OK(
1219 context,
1220 PrecomputeZeroPointTimesWeightWithBias(
1221 context, input_zero_point, input_to_forget_weights, forget_gate_bias,
1222 &(integer_lstm_params->input_to_forget_effective_bias)));
1223
1224 TF_LITE_ENSURE_OK(
1225 context,
1226 PrecomputeZeroPointTimesWeightWithBias(
1227 context, output_state_zero_point, recurrent_to_forget_weights,
1228 nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));
1229
1230 // Modulation gate.
1231 const TfLiteTensor* cell_gate_bias =
1232 is_layer_norm ? nullptr : GetInput(context, node, kCellGateBiasTensor);
1233 TF_LITE_ENSURE_OK(
1234 context,
1235 PrecomputeZeroPointTimesWeightWithBias(
1236 context, input_zero_point, input_to_cell_weights, cell_gate_bias,
1237 &(integer_lstm_params->input_to_cell_effective_bias)));
1238 TF_LITE_ENSURE_OK(
1239 context,
1240 PrecomputeZeroPointTimesWeightWithBias(
1241 context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
1242 &(integer_lstm_params->recurrent_to_cell_effective_bias)));
1243
1244 // Output gate.
1245 const TfLiteTensor* output_gate_bias =
1246 is_layer_norm ? nullptr : GetInput(context, node, kOutputGateBiasTensor);
1247 TF_LITE_ENSURE_OK(
1248 context,
1249 PrecomputeZeroPointTimesWeightWithBias(
1250 context, input_zero_point, input_to_output_weights, output_gate_bias,
1251 &(integer_lstm_params->input_to_output_effective_bias)));
1252
1253 TF_LITE_ENSURE_OK(
1254 context,
1255 PrecomputeZeroPointTimesWeightWithBias(
1256 context, output_state_zero_point, recurrent_to_output_weights,
1257 nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));
1258
1259 // Input gate. The calculation is only meaningful for non-cifg case.
1260 const TfLiteTensor* input_gate_bias =
1261 is_layer_norm ? nullptr : GetInput(context, node, kInputGateBiasTensor);
1262 TF_LITE_ENSURE_OK(
1263 context,
1264 PrecomputeZeroPointTimesWeightWithBias(
1265 context, input_zero_point, input_to_input_weights, input_gate_bias,
1266 &(integer_lstm_params->input_to_input_effective_bias)));
1267 TF_LITE_ENSURE_OK(
1268 context,
1269 PrecomputeZeroPointTimesWeightWithBias(
1270 context, output_state_zero_point, recurrent_to_input_weights, nullptr,
1271 &(integer_lstm_params->recurrent_to_input_effective_bias)));
1272
1273 // Projection bias. The calculation is only meaningful for with projection.
1274 TF_LITE_ENSURE_OK(context,
1275 PrecomputeZeroPointTimesWeightWithBias(
1276 context, hidden_zp, projection_weights, projection_bias,
1277 &(integer_lstm_params->projection_effective_bias)));
1278 return kTfLiteOk;
1279}
1280
1281// Resize the output, state tensors based on the sizes of the input tensors.
1282// Allocate a temporary scratch tensor. Also check that the sizes of the input
1283// tensors match each other.
1284// LINT.IfChange
1285TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
1286 OpData* op_data = static_cast<OpData*>(node->user_data);
1287
1288 TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
1289 // Logic for determining regular lstm and layer norm lstm:
1290 // input_size, forget_gate_layer_norm_tensor (20) null? is_layer_norm?
1291 // 20, N/A, No.
1292 // 24, null, No.
1293 // 24, not null, Yes.
1294 // 20-inputs lstm are deprecated and is only kept here for backward
1295 // compatibility.
1296 if (node->inputs->size == 24) {
1297 const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
1298 context, node, kForgetLayerNormCoefficientsTensor);
1299 if (forget_layer_norm_coefficients == nullptr) {
1300 op_data->use_layer_norm = false;
1301 } else {
1302 op_data->use_layer_norm = true;
1303 }
1304 } else if (node->inputs->size == 20) {
1305 // This is deprecated and is only kept here for backward compatibility.
1306 op_data->use_layer_norm = false;
1307 } else {
1308 TF_LITE_KERNEL_LOG(
1309 context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
1310 node->inputs->size);
1311 return kTfLiteError;
1312 }
1313
1314 const bool use_layer_norm = op_data->use_layer_norm;
1315
1316 // Inferring batch size, number of outputs and number of cells from the
1317 // input tensors.
1318 const TfLiteTensor* input;
1319 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
1320 const bool is_integer = input->type == kTfLiteInt8;
1321 TF_LITE_ENSURE(context, input->dims->size > 1);
1322 const int n_batch = input->dims->data[0];
1323 const int n_input = input->dims->data[1];
1324
1325 const TfLiteTensor* input_to_output_weights;
1326 TF_LITE_ENSURE_OK(context,
1327 GetInputSafe(context, node, kInputToOutputWeightsTensor,
1328 &input_to_output_weights));
1329 const int n_cell = input_to_output_weights->dims->data[0];
1330 TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
1331 TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
1332
1333 const TfLiteTensor* recurrent_to_output_weights;
1334 TF_LITE_ENSURE_OK(context,
1335 GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
1336 &recurrent_to_output_weights));
1337 TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
1338 TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
1339 n_cell);
1340 const int n_output = recurrent_to_output_weights->dims->data[1];
1341
1342 // Check that input tensor dimensions matches with each other.
1343 TF_LITE_ENSURE_OK(
1344 context, CheckInputTensorDimensions(context, node, n_input, n_output,
1345 n_cell, use_layer_norm, is_integer));
1346
1347 // Get the pointer to output, output_state and cell_state tensors.
1348 TfLiteTensor* output;
1349 TF_LITE_ENSURE_OK(context,
1350 GetOutputSafe(context, node, kOutputTensor, &output));
1351
1352 TfLiteTensor* output_state =
1353 GetVariableInput(context, node, kOutputStateTensor);
1354 TF_LITE_ENSURE(context, output_state != nullptr);
1355 TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
1356 TF_LITE_ENSURE(context, cell_state != nullptr);
1357
1358 // Check the shape of input state tensors.
1359 // These tensor may be 1D or 2D. It's fine as long as the total size is
1360 // correct.
1361 TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
1362 TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
1363
1364 // Resize the output tensors.
1365 TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
1366 output_size->data[0] = n_batch;
1367 output_size->data[1] = n_output;
1368 TF_LITE_ENSURE_OK(context,
1369 context->ResizeTensor(context, output, output_size));
1370
1371 // The weights are of consistent type, so it suffices to check one.
1372 const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights);
1373
1374 const bool is_sparse_op = (input_to_output_weights->sparsity != nullptr);
1375
1376 // The type of Integer LSTM.
1377 const int num_intermediate_tensors = node->intermediates->size;
1378 if (is_integer) {
1379 TF_LITE_ENSURE(context, num_intermediate_tensors == 5 ||
1380 num_intermediate_tensors == 12);
1381 }
1382 // We use number of intermediate tensors to distinguish the 8 bit matmul
1383 // output and the 16 bit matmul output version.
1384 const bool is_8x8_16 = num_intermediate_tensors == 5;
1385
1386 TfLiteIntArrayFree(node->temporaries);
1387 if (is_hybrid_op) {
1388 if (is_sparse_op) {
1389 node->temporaries =
1390 TfLiteIntArrayCreate(kNumHybridTemporaryTensors + kLedgersToAdd);
1391 } else {
1392 node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors);
1393 }
1394 } else if (is_integer) {
1395 if (is_8x8_16) {
1396 node->temporaries = TfLiteIntArrayCreate(6);
1397 } else {
1398 node->temporaries = TfLiteIntArrayCreate(8);
1399 }
1400 } else {
1401 node->temporaries = TfLiteIntArrayCreate(1);
1402 }
1403
1404 // Create a scratch buffer tensor for float case and hybrid case.
1405 // TODO(b/152066492): Create a is_float boolean and reorganize the temporary
1406 // buffer allocation logic.
1407 if (!is_integer) {
1408 node->temporaries->data[kScratchBuffer] =
1409 op_data->scratch_tensor_index + kScratchBuffer;
1410 TfLiteTensor* scratch_buffer;
1411 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
1412 &scratch_buffer));
1413 scratch_buffer->type = input->type;
1414 scratch_buffer->allocation_type = kTfLiteArenaRw;
1415
1416 const TfLiteTensor* input_to_input_weights =
1417 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1418 const bool use_cifg = (input_to_input_weights == nullptr);
1419 TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
1420 scratch_buffer_size->data[0] = n_batch;
1421 if (use_cifg) {
1422 // Reserving space for Cell, Forget, Output gates and scratch accumulation
1423 // buffer and an extra 16 bytes to avoid internal ruy copies.
1424 scratch_buffer_size->data[1] = n_cell * 4;
1425 } else {
1426 // Reserving space for Input, Cell, Forget, Output gates and scratch
1427 // accumulation buffer and an extra 16 bytes to avoid internal ruy copies.
1428 scratch_buffer_size->data[1] = n_cell * 5;
1429 }
1430 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
1431 scratch_buffer_size));
1432 }
1433
1434 if (is_hybrid_op) {
1435 if (!is_sparse_op) {
1436 op_data->compute_row_sums = true;
1437 }
1438 // Allocate temporary tensors to store quantized values of input,
1439 // output_state and cell_state tensors.
1440 node->temporaries->data[kInputQuantized] =
1441 op_data->scratch_tensor_index + kInputQuantized;
1442 TfLiteTensor* input_quantized;
1443 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
1444 &input_quantized));
1445 input_quantized->type = input_to_output_weights->type;
1446 input_quantized->allocation_type = kTfLiteArenaRw;
1447 if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
1448 TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
1449 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
1450 input_quantized_size));
1451 }
1452 node->temporaries->data[kOutputStateQuantized] =
1453 op_data->scratch_tensor_index + kOutputStateQuantized;
1454 TfLiteTensor* output_state_quantized;
1455 TF_LITE_ENSURE_OK(context,
1456 GetTemporarySafe(context, node, kOutputStateQuantized,
1457 &output_state_quantized));
1458 output_state_quantized->type = input_to_output_weights->type;
1459 output_state_quantized->allocation_type = kTfLiteArenaRw;
1460 if (!TfLiteIntArrayEqual(output_state_quantized->dims,
1461 output_state->dims)) {
1462 TfLiteIntArray* output_state_quantized_size =
1463 TfLiteIntArrayCopy(output_state->dims);
1464 TF_LITE_ENSURE_OK(context,
1465 context->ResizeTensor(context, output_state_quantized,
1466 output_state_quantized_size));
1467 }
1468 node->temporaries->data[kCellStateQuantized] =
1469 op_data->scratch_tensor_index + kCellStateQuantized;
1470 TfLiteTensor* cell_state_quantized;
1471 TF_LITE_ENSURE_OK(context,
1472 GetTemporarySafe(context, node, kCellStateQuantized,
1473 &cell_state_quantized));
1474 cell_state_quantized->type = input_to_output_weights->type;
1475 cell_state_quantized->allocation_type = kTfLiteArenaRw;
1476 if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
1477 TfLiteIntArray* cell_state_quantized_size =
1478 TfLiteIntArrayCopy(cell_state->dims);
1479 TF_LITE_ENSURE_OK(context,
1480 context->ResizeTensor(context, cell_state_quantized,
1481 cell_state_quantized_size));
1482 }
1483 // Allocate temporary tensors to store scaling factors and product scaling
1484 // factors. The latter is a convenience storage which allows to quantize
1485 // a vector once (which produces the scaling factors) and multiply it with
1486 // different matrices (which requires multiplying the scaling factors with
1487 // the scaling factor of the matrix).
1488 node->temporaries->data[kInputScalingFactors] =
1489 op_data->scratch_tensor_index + kInputScalingFactors;
1490 TfLiteTensor* input_sf;
1491 TF_LITE_ENSURE_OK(
1492 context,
1493 GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
1494 input_sf->type = kTfLiteFloat32;
1495 input_sf->allocation_type = kTfLiteArenaRw;
1496 int scaling_dims[1] = {n_batch};
1497 if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
1498 TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
1499 input_sf_size->data[0] = n_batch;
1500 TF_LITE_ENSURE_OK(
1501 context, context->ResizeTensor(context, input_sf, input_sf_size));
1502 }
1503 node->temporaries->data[kOutputStateScalingFactors] =
1504 op_data->scratch_tensor_index + kOutputStateScalingFactors;
1505 TfLiteTensor* output_state_sf;
1506 TF_LITE_ENSURE_OK(
1507 context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
1508 &output_state_sf));
1509 output_state_sf->type = kTfLiteFloat32;
1510 output_state_sf->allocation_type = kTfLiteArenaRw;
1511 if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
1512 TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
1513 output_state_sf_size->data[0] = n_batch;
1514 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
1515 output_state_sf_size));
1516 }
1517 node->temporaries->data[kProductScalingFactors] =
1518 op_data->scratch_tensor_index + kProductScalingFactors;
1519 TfLiteTensor* prod_scaling_factors;
1520 TF_LITE_ENSURE_OK(context,
1521 GetTemporarySafe(context, node, kProductScalingFactors,
1522 &prod_scaling_factors));
1523 prod_scaling_factors->type = kTfLiteFloat32;
1524 prod_scaling_factors->allocation_type = kTfLiteArenaRw;
1525 if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
1526 scaling_dims)) {
1527 TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
1528 prod_scaling_factors_size->data[0] = n_batch;
1529 TF_LITE_ENSURE_OK(context,
1530 context->ResizeTensor(context, prod_scaling_factors,
1531 prod_scaling_factors_size));
1532 }
1533
1534 // Allocate a temporary tensor to store the recovered cell weights. Since
1535 // this is used for diagonal matrices, only need to store n_cell values.
1536 node->temporaries->data[kRecoveredCellWeights] =
1537 op_data->scratch_tensor_index + kRecoveredCellWeights;
1538 TfLiteTensor* recovered_cell_weights;
1539 TF_LITE_ENSURE_OK(context,
1540 GetTemporarySafe(context, node, kRecoveredCellWeights,
1541 &recovered_cell_weights));
1542 recovered_cell_weights->type = kTfLiteFloat32;
1543 recovered_cell_weights->allocation_type = kTfLiteArenaRw;
1544 int recovered_cell_dims[1] = {n_cell};
1545 if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
1546 recovered_cell_dims)) {
1547 TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
1548 recovered_cell_weights_size->data[0] = n_cell;
1549 TF_LITE_ENSURE_OK(context,
1550 context->ResizeTensor(context, recovered_cell_weights,
1551 recovered_cell_weights_size));
1552 }
1553 // Allocate a temporary tensor to store accumulate values for matrix
1554 // multiplication before multiplication by scaling factor
1555 node->temporaries->data[kAccumScratch] =
1556 op_data->scratch_tensor_index + kAccumScratch;
1557 TfLiteTensor* accum_scratch;
1558 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
1559 &accum_scratch));
1560 accum_scratch->type = kTfLiteInt32;
1561 accum_scratch->allocation_type = kTfLiteArenaRw;
1562 int accum_scratch_dims[2] = {n_cell, n_batch};
1563 if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
1564 accum_scratch_dims)) {
1565 TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
1566 accum_size->data[0] = n_cell;
1567 accum_size->data[1] = n_batch;
1568 TF_LITE_ENSURE_OK(
1569 context, context->ResizeTensor(context, accum_scratch, accum_size));
1570 }
1571 node->temporaries->data[kInputZeroPoints] =
1572 op_data->scratch_tensor_index + kInputZeroPoints;
1573 TfLiteTensor* input_zp;
1574 TF_LITE_ENSURE_OK(
1575 context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
1576 input_zp->type = kTfLiteFloat32;
1577 input_zp->allocation_type = kTfLiteArenaRw;
1578 if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
1579 TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
1580 input_zp_size->data[0] = n_batch;
1581 TF_LITE_ENSURE_OK(
1582 context, context->ResizeTensor(context, input_zp, input_zp_size));
1583 }
1584 node->temporaries->data[kOutputStateZeroPoints] =
1585 op_data->scratch_tensor_index + kOutputStateZeroPoints;
1586 TfLiteTensor* output_state_zp;
1587 TF_LITE_ENSURE_OK(context,
1588 GetTemporarySafe(context, node, kOutputStateZeroPoints,
1589 &output_state_zp));
1590 output_state_zp->type = kTfLiteFloat32;
1591 output_state_zp->allocation_type = kTfLiteArenaRw;
1592 if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
1593 TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
1594 output_state_zp_size->data[0] = n_batch;
1595 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
1596 output_state_zp_size));
1597 }
1598
1599 node->temporaries->data[kRowSums] =
1600 op_data->scratch_tensor_index + kRowSums;
1601 const TfLiteTensor* input_to_input_weights =
1602 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1603 const bool use_cifg = (input_to_input_weights == nullptr);
1604 int row_sums_rows = use_cifg ? 6 : 8;
1605 const TfLiteTensor* projection_weights =
1606 GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
1607 if (projection_weights != nullptr) {
1608 row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
1609 }
1610
1611 TfLiteTensor* row_sums;
1612 TF_LITE_ENSURE_OK(context,
1613 GetTemporarySafe(context, node, kRowSums, &row_sums));
1614 row_sums->type = kTfLiteInt32;
1615 row_sums->name = "Lstm_row_sums";
1616 row_sums->allocation_type = kTfLiteArenaRwPersistent;
1617 const int row_sums_dims[2] = {row_sums_rows, n_cell};
1618 if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
1619 TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
1620 row_sums_size->data[0] = row_sums_dims[0];
1621 row_sums_size->data[1] = row_sums_dims[1];
1622 TF_LITE_ENSURE_OK(
1623 context, context->ResizeTensor(context, row_sums, row_sums_size));
1624 }
1625
1626 if (is_sparse_op) {
1627 op_data->ledger_initialized = false;
1628 int offset = kNumHybridTemporaryTensors;
1629 {
1630 node->temporaries->data[offset + kInputToInputWeightsLedgerOffset] =
1631 op_data->ledger_index + kInputToInputWeightsLedgerOffset;
1632 const TfLiteTensor* input_to_input_weights =
1633 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1634 TfLiteTensor* input_to_input_weights_ledger =
1635 &context->tensors[op_data->ledger_index +
1636 kInputToInputWeightsLedgerOffset];
1637 auto status = make_ledger(input_to_input_weights == nullptr
1638 ? nullptr
1639 : input_to_input_weights->sparsity,
1640 context, input_to_input_weights_ledger);
1641 if (status != kTfLiteOk) return status;
1642 }
1643 {
1644 node->temporaries->data[offset + kInputToForgetWeightsLedgerOffset] =
1645 op_data->ledger_index + kInputToForgetWeightsLedgerOffset;
1646 const TfLiteTensor* input_to_forget_weights =
1647 GetInput(context, node, kInputToForgetWeightsTensor);
1648 TfLiteTensor* input_to_forget_weights_ledger =
1649 &context->tensors[op_data->ledger_index +
1650 kInputToForgetWeightsLedgerOffset];
1651 auto status = make_ledger(input_to_forget_weights->sparsity, context,
1652 input_to_forget_weights_ledger);
1653 if (status != kTfLiteOk) return status;
1654 }
1655 {
1656 node->temporaries->data[offset + kInputToCellWeightsLedgerOffset] =
1657 op_data->ledger_index + kInputToCellWeightsLedgerOffset;
1658 const TfLiteTensor* input_to_cell_weights =
1659 GetInput(context, node, kInputToCellWeightsTensor);
1660 TfLiteTensor* input_to_cell_weights_ledger =
1661 &context->tensors[op_data->ledger_index +
1662 kInputToCellWeightsLedgerOffset];
1663 auto status = make_ledger(input_to_cell_weights->sparsity, context,
1664 input_to_cell_weights_ledger);
1665 if (status != kTfLiteOk) return status;
1666 }
1667 {
1668 node->temporaries->data[offset + kInputToOutputWeightsLedgerOffset] =
1669 op_data->ledger_index + kInputToOutputWeightsLedgerOffset;
1670 const TfLiteTensor* input_to_output_weights =
1671 GetInput(context, node, kInputToOutputWeightsTensor);
1672 TfLiteTensor* input_to_output_weights_ledger =
1673 &context->tensors[op_data->ledger_index +
1674 kInputToOutputWeightsLedgerOffset];
1675 auto status = make_ledger(input_to_output_weights->sparsity, context,
1676 input_to_output_weights_ledger);
1677 if (status != kTfLiteOk) return status;
1678 }
1679 {
1680 node->temporaries->data[offset + kRecurrentToInputWeightsLedgerOffset] =
1681 op_data->ledger_index + kRecurrentToInputWeightsLedgerOffset;
1682 const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
1683 context, node, kRecurrentToInputWeightsTensor);
1684 TfLiteTensor* recurrent_to_input_weights_ledger =
1685 &context->tensors[op_data->ledger_index +
1686 kRecurrentToInputWeightsLedgerOffset];
1687 auto status = make_ledger(recurrent_to_input_weights == nullptr
1688 ? nullptr
1689 : recurrent_to_input_weights->sparsity,
1690 context, recurrent_to_input_weights_ledger);
1691 if (status != kTfLiteOk) return status;
1692 }
1693 {
1694 node->temporaries
1695 ->data[offset + kRecurrentToForgetWeightsLedgerOffset] =
1696 op_data->ledger_index + kRecurrentToForgetWeightsLedgerOffset;
1697 const TfLiteTensor* recurrent_to_forget_weights =
1698 GetInput(context, node, kRecurrentToForgetWeightsTensor);
1699 TfLiteTensor* recurrent_to_forget_weights_ledger =
1700 &context->tensors[op_data->ledger_index +
1701 kRecurrentToForgetWeightsLedgerOffset];
1702 auto status = make_ledger(recurrent_to_forget_weights->sparsity,
1703 context, recurrent_to_forget_weights_ledger);
1704 if (status != kTfLiteOk) return status;
1705 }
1706 {
1707 node->temporaries->data[offset + kRecurrentToCellWeightsLedgerOffset] =
1708 op_data->ledger_index + kRecurrentToCellWeightsLedgerOffset;
1709 const TfLiteTensor* recurrent_to_cell_weights =
1710 GetInput(context, node, kRecurrentToCellWeightsTensor);
1711 TfLiteTensor* recurrent_to_cell_weights_ledger =
1712 &context->tensors[op_data->ledger_index +
1713 kRecurrentToCellWeightsLedgerOffset];
1714 auto status = make_ledger(recurrent_to_cell_weights->sparsity, context,
1715 recurrent_to_cell_weights_ledger);
1716 if (status != kTfLiteOk) return status;
1717 }
1718 {
1719 node->temporaries
1720 ->data[offset + kRecurrentToOutputWeightsLedgerOffset] =
1721 op_data->ledger_index + kRecurrentToOutputWeightsLedgerOffset;
1722 const TfLiteTensor* recurrent_to_output_weights =
1723 GetInput(context, node, kRecurrentToOutputWeightsTensor);
1724 TfLiteTensor* recurrent_to_output_weights_ledger =
1725 &context->tensors[op_data->ledger_index +
1726 kRecurrentToOutputWeightsLedgerOffset];
1727 auto status = make_ledger(recurrent_to_output_weights->sparsity,
1728 context, recurrent_to_output_weights_ledger);
1729 if (status != kTfLiteOk) return status;
1730 }
1731 {
1732 node->temporaries->data[offset + kProjectionWeightsLedgerOffset] =
1733 op_data->ledger_index + kProjectionWeightsLedgerOffset;
1734 const TfLiteTensor* projection_weights =
1735 GetInput(context, node, kProjectionWeightsTensor);
1736 TfLiteTensor* projection_weights_ledger =
1737 &context->tensors[op_data->ledger_index +
1738 kProjectionWeightsLedgerOffset];
1739 auto status = make_ledger(projection_weights->sparsity, context,
1740 projection_weights_ledger);
1741 if (status != kTfLiteOk) return status;
1742 }
1743 }
1744 }
1745
1746 if (is_integer) {
1747 if (is_8x8_16) {
1748 // Integer LSTM prepare function for 8x8->16.
1749 // This code path needs 5 intermediate tensors per Op.
1750 // Populate quantization parameters.
1751 PopulateQuantizedLstmParams8x8_16(context, node,
1752 &op_data->integer_lstm_param);
1753
1754 // Allocate scratch buffer. Need 6 16bit buffer with size n_batch * n_cell
1755 // and 1 8bit buffer with size n_batch * n_cell. We also need 1 32 bit
1756 // buffer with size n_batch * n_cell.
1757 //
1758 // Handle cifg case as well, which might save one buffer.
1759 for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
1760 node->temporaries->data[scratch_index] =
1761 op_data->scratch_tensor_index + scratch_index;
1762 TfLiteTensor* scratch_tensor;
1763 TF_LITE_ENSURE_OK(
1764 context,
1765 GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
1766 scratch_tensor->type = kTfLiteInt16;
1767 if (scratch_index == 4) {
1768 scratch_tensor->type = kTfLiteInt8;
1769 } else if (scratch_index == 5) {
1770 scratch_tensor->type = kTfLiteInt32;
1771 }
1772 scratch_tensor->allocation_type = kTfLiteArenaRw;
1773 const int scratch_dimension[2] = {n_batch, n_cell};
1774 if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
1775 scratch_dimension)) {
1776 TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
1777 scratch_buffer_size->data[0] = n_batch;
1778 scratch_buffer_size->data[1] = n_cell;
1779 TF_LITE_ENSURE_OK(context,
1780 context->ResizeTensor(context, scratch_tensor,
1781 scratch_buffer_size));
1782 }
1783 }
1784
1785 // Populate precomputed zp * weight.
1786 TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
1787 context, op_data, node));
1788 } else {
1789 // Integer LSTM prepare function for 8x8->8.
1790 // This code path needs 12 intermediate tensors per Op.
1791 PopulateQuantizedLstmParams8x8_8(context, node,
1792 &op_data->integer_lstm_param);
1793
1794 // Allocate scratch buffer. Need 6 16bit buffer with size n_batch * n_cell
1795 // and 2 8bit buffer with size n_batch * n_cell.
1796 //
1797 // Handle cifg case as well, which might save one buffer.
1798 for (int scratch_index = 0; scratch_index < 8; ++scratch_index) {
1799 node->temporaries->data[scratch_index] =
1800 op_data->scratch_tensor_index + scratch_index;
1801 TfLiteTensor* scratch_tensor;
1802 TF_LITE_ENSURE_OK(
1803 context,
1804 GetTemporarySafe(context, node, scratch_index, &scratch_tensor));
1805 if (scratch_index == 0 || scratch_index == 1) {
1806 scratch_tensor->type = kTfLiteInt8;
1807 } else {
1808 scratch_tensor->type = kTfLiteInt16;
1809 }
1810 scratch_tensor->allocation_type = kTfLiteArenaRw;
1811 const int scratch_dimension[2] = {n_batch, n_cell};
1812 if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
1813 scratch_dimension)) {
1814 TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
1815 scratch_buffer_size->data[0] = n_batch;
1816 scratch_buffer_size->data[1] = n_cell;
1817 TF_LITE_ENSURE_OK(context,
1818 context->ResizeTensor(context, scratch_tensor,
1819 scratch_buffer_size));
1820 }
1821 }
1822 }
1823 }
1824 return kTfLiteOk;
1825}
1826// LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
1827
1828// LINT.IfChange
1829TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
1830 const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data);
1831 OpData* op_data = static_cast<OpData*>(node->user_data);
1832
1833 const TfLiteTensor* input;
1834 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
1835
1836 const TfLiteTensor* input_to_input_weights =
1837 GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
1838 const TfLiteTensor* input_to_forget_weights;
1839 TF_LITE_ENSURE_OK(context,
1840 GetInputSafe(context, node, kInputToForgetWeightsTensor,
1841 &input_to_forget_weights));
1842 const TfLiteTensor* input_to_cell_weights;
1843 TF_LITE_ENSURE_OK(context,
1844 GetInputSafe(context, node, kInputToCellWeightsTensor,
1845 &input_to_cell_weights));
1846 const TfLiteTensor* input_to_output_weights;
1847 TF_LITE_ENSURE_OK(context,
1848 GetInputSafe(context, node, kInputToOutputWeightsTensor,
1849 &input_to_output_weights));
1850
1851 const TfLiteTensor* recurrent_to_input_weights =
1852 GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
1853 const TfLiteTensor* recurrent_to_forget_weights;
1854 TF_LITE_ENSURE_OK(context,
1855 GetInputSafe(context, node, kRecurrentToForgetWeightsTensor,
1856 &recurrent_to_forget_weights));
1857 const TfLiteTensor* recurrent_to_cell_weights;
1858 TF_LITE_ENSURE_OK(context,
1859 GetInputSafe(context, node, kRecurrentToCellWeightsTensor,
1860 &recurrent_to_cell_weights));
1861 const TfLiteTensor* recurrent_to_output_weights;
1862 TF_LITE_ENSURE_OK(context,
1863 GetInputSafe(context, node, kRecurrentToOutputWeightsTensor,
1864 &recurrent_to_output_weights));
1865
1866 const TfLiteTensor* cell_to_input_weights =
1867 GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
1868 const TfLiteTensor* cell_to_forget_weights =
1869 GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
1870 const TfLiteTensor* cell_to_output_weights =
1871 GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
1872
1873 const TfLiteTensor* input_layer_norm_coefficients =
1874 GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor);
1875 const TfLiteTensor* forget_layer_norm_coefficients =
1876 GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor);
1877 const TfLiteTensor* cell_layer_norm_coefficients =
1878 GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor);
1879 const TfLiteTensor* output_layer_norm_coefficients =
1880 GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor);
1881
1882 const TfLiteTensor* input_gate_bias =
1883 GetOptionalInputTensor(context, node, kInputGateBiasTensor);
1884 const TfLiteTensor* forget_gate_bias;
1885 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor,
1886 &forget_gate_bias));
1887 const TfLiteTensor* cell_gate_bias;
1888 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor,
1889 &cell_gate_bias));
1890 const TfLiteTensor* output_gate_bias;
1891 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor,
1892 &output_gate_bias));
1893
1894 const TfLiteTensor* projection_weights =
1895 GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
1896 const TfLiteTensor* projection_bias =
1897 GetOptionalInputTensor(context, node, kProjectionBiasTensor);
1898
1899 TfLiteTensor* output_state =
1900 GetVariableInput(context, node, kOutputStateTensor);
1901 TFLITE_DCHECK(output_state != nullptr);
1902 TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor);
1903 TFLITE_DCHECK(cell_state != nullptr);
1904
1905 TfLiteTensor* output;
1906 TF_LITE_ENSURE_OK(context,
1907 GetOutputSafe(context, node, kOutputTensor, &output));
1908
1909 switch (input_to_output_weights->type) {
1910 case kTfLiteFloat32: {
1911 // Index the scratch buffers pointers to the global scratch buffer.
1912 TfLiteTensor* scratch_buffer;
1913 TF_LITE_ENSURE_OK(context,
1914 GetTemporarySafe(context, node, 0, &scratch_buffer));
1915 return lstm_eval::EvalFloat(
1916 input, input_to_input_weights, input_to_forget_weights,
1917 input_to_cell_weights, input_to_output_weights,
1918 recurrent_to_input_weights, recurrent_to_forget_weights,
1919 recurrent_to_cell_weights, recurrent_to_output_weights,
1920 cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
1921 input_layer_norm_coefficients, forget_layer_norm_coefficients,
1922 cell_layer_norm_coefficients, output_layer_norm_coefficients,
1923 /*aux_input=*/nullptr,
1924 /*aux_input_to_input_weights=*/nullptr,
1925 /*aux_input_to_forget_weights=*/nullptr,
1926 /*aux_input_to_cell_weights=*/nullptr,
1927 /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
1928 forget_gate_bias, cell_gate_bias, output_gate_bias,
1929 projection_weights, projection_bias, params,
1930 /*forward_sequence=*/true,
1931 /*time_major=*/true,
1932 /*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
1933 CpuBackendContext::GetFromContext(context));
1934 }
1935 case kTfLiteUInt8:
1936 case kTfLiteInt8: {
1937 const bool is_hybrid = (input->type == kTfLiteFloat32);
1938 const bool is_sparse = input_to_output_weights->sparsity != nullptr;
1939 if (is_hybrid) {
1940 TfLiteTensor* row_sums;
1941 TF_LITE_ENSURE_OK(context,
1942 GetTemporarySafe(context, node, kRowSums, &row_sums));
1943 const int row_sums_size = row_sums->dims->data[0];
1944 if (is_sparse) {
1945 TfLiteTensor* input_to_input_weights_ledger =
1946 &context->tensors[op_data->ledger_index +
1947 kInputToInputWeightsLedgerOffset];
1948 TfLiteTensor* input_to_forget_weights_ledger =
1949 &context->tensors[op_data->ledger_index +
1950 kInputToForgetWeightsLedgerOffset];
1951 TfLiteTensor* input_to_cell_weights_ledger =
1952 &context->tensors[op_data->ledger_index +
1953 kInputToCellWeightsLedgerOffset];
1954 TfLiteTensor* input_to_output_weights_ledger =
1955 &context->tensors[op_data->ledger_index +
1956 kInputToOutputWeightsLedgerOffset];
1957 TfLiteTensor* recurrent_to_input_weights_ledger =
1958 &context->tensors[op_data->ledger_index +
1959 kRecurrentToInputWeightsLedgerOffset];
1960 TfLiteTensor* recurrent_to_forget_weights_ledger =
1961 &context->tensors[op_data->ledger_index +
1962 kRecurrentToForgetWeightsLedgerOffset];
1963 TfLiteTensor* recurrent_to_cell_weights_ledger =
1964 &context->tensors[op_data->ledger_index +
1965 kRecurrentToCellWeightsLedgerOffset];
1966 TfLiteTensor* recurrent_to_output_weights_ledger =
1967 &context->tensors[op_data->ledger_index +
1968 kRecurrentToOutputWeightsLedgerOffset];
1969 TfLiteTensor* projection_weights_ledger =
1970 &context->tensors[op_data->ledger_index +
1971 kProjectionWeightsLedgerOffset];
1972 if (!op_data->ledger_initialized) {
1973 copy_ledger(input_to_input_weights == nullptr
1974 ? nullptr
1975 : input_to_input_weights->sparsity,
1976 input_to_input_weights_ledger);
1977 copy_ledger(input_to_forget_weights->sparsity,
1978 input_to_forget_weights_ledger);
1979 copy_ledger(input_to_cell_weights->sparsity,
1980 input_to_cell_weights_ledger);
1981 copy_ledger(input_to_output_weights->sparsity,
1982 input_to_output_weights_ledger);
1983 copy_ledger(recurrent_to_input_weights == nullptr
1984 ? nullptr
1985 : recurrent_to_input_weights->sparsity,
1986 recurrent_to_input_weights_ledger);
1987 copy_ledger(recurrent_to_forget_weights->sparsity,
1988 recurrent_to_forget_weights_ledger);
1989 copy_ledger(recurrent_to_cell_weights->sparsity,
1990 recurrent_to_cell_weights_ledger);
1991 copy_ledger(recurrent_to_output_weights->sparsity,
1992 recurrent_to_output_weights_ledger);
1993 copy_ledger(projection_weights->sparsity,
1994 projection_weights_ledger);
1995 op_data->ledger_initialized = true;
1996 }
1997 return lstm_eval::EvalHybrid(
1998 input, input_to_input_weights, input_to_input_weights_ledger,
1999 input_to_forget_weights, input_to_forget_weights_ledger,
2000 input_to_cell_weights, input_to_cell_weights_ledger,
2001 input_to_output_weights, input_to_output_weights_ledger,
2002 recurrent_to_input_weights, recurrent_to_input_weights_ledger,
2003 recurrent_to_forget_weights, recurrent_to_forget_weights_ledger,
2004 recurrent_to_cell_weights, recurrent_to_cell_weights_ledger,
2005 recurrent_to_output_weights, recurrent_to_output_weights_ledger,
2006 cell_to_input_weights, cell_to_forget_weights,
2007 cell_to_output_weights, input_layer_norm_coefficients,
2008 forget_layer_norm_coefficients, cell_layer_norm_coefficients,
2009 output_layer_norm_coefficients,
2010 /*aux_input=*/nullptr,
2011 /*aux_input_to_input_weights=*/nullptr,
2012 /*aux_input_to_forget_weights=*/nullptr,
2013 /*aux_input_to_cell_weights=*/nullptr,
2014 /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
2015 forget_gate_bias, cell_gate_bias, output_gate_bias,
2016 projection_weights, projection_weights_ledger, projection_bias,
2017 params,
2018 /*forward_sequence=*/true, /*time_major=*/true,
2019 /*output_offset=*/0, GetTemporary(context, node, kScratchBuffer),
2020 GetTemporary(context, node, kInputScalingFactors),
2021 /*aux_input_sf=*/nullptr,
2022 GetTemporary(context, node, kOutputStateScalingFactors),
2023 GetTemporary(context, node, kProductScalingFactors),
2024 GetTemporary(context, node, kRecoveredCellWeights),
2025 GetTemporary(context, node, kInputQuantized),
2026 /*aux_input_quantized=*/nullptr,
2027 GetTemporary(context, node, kOutputStateQuantized),
2028 GetTemporary(context, node, kCellStateQuantized), output_state,
2029 cell_state, GetTemporary(context, node, kAccumScratch), output,
2030 GetTemporary(context, node, kInputZeroPoints),
2031 /*aux_input_zp=*/nullptr,
2032 GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
2033 row_sums_size, &op_data->compute_row_sums,
2034 CpuBackendContext::GetFromContext(context));
2035 }
2036 return lstm_eval::EvalHybrid(
2037 input, input_to_input_weights,
2038 /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
2039 /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
2040 /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
2041 /*input_to_output_weights_ledger*/ nullptr,
2042 recurrent_to_input_weights,
2043 /*recurrent_to_input_weights_ledger*/ nullptr,
2044 recurrent_to_forget_weights,
2045 /*recurrent_to_forget_weights_ledger*/ nullptr,
2046 recurrent_to_cell_weights,
2047 /*recurrent_to_cell_weights_ledger*/ nullptr,
2048 recurrent_to_output_weights,
2049 /*recurrent_to_output_weights_ledger*/ nullptr,
2050 cell_to_input_weights, cell_to_forget_weights,
2051 cell_to_output_weights, input_layer_norm_coefficients,
2052 forget_layer_norm_coefficients, cell_layer_norm_coefficients,
2053 output_layer_norm_coefficients, /*aux_input=*/nullptr,
2054 /*aux_input_to_input_weights=*/nullptr,
2055 /*aux_input_to_forget_weights=*/nullptr,
2056 /*aux_input_to_cell_weights=*/nullptr,
2057 /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
2058 forget_gate_bias, cell_gate_bias, output_gate_bias,
2059 projection_weights, /*projection_weights_ledger*/ nullptr,
2060 projection_bias, params,
2061 /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0,
2062 GetTemporary(context, node, kScratchBuffer),
2063 GetTemporary(context, node, kInputScalingFactors),
2064 /*aux_input_sf=*/nullptr,
2065 GetTemporary(context, node, kOutputStateScalingFactors),
2066 GetTemporary(context, node, kProductScalingFactors),
2067 GetTemporary(context, node, kRecoveredCellWeights),
2068 GetTemporary(context, node, kInputQuantized),
2069 /*aux_input_quantized=*/nullptr,
2070 GetTemporary(context, node, kOutputStateQuantized),
2071 GetTemporary(context, node, kCellStateQuantized), output_state,
2072 cell_state, GetTemporary(context, node, kAccumScratch), output,
2073 GetTemporary(context, node, kInputZeroPoints),
2074 /*aux_input_zp=*/nullptr,
2075 GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
2076 row_sums_size, &op_data->compute_row_sums,
2077 CpuBackendContext::GetFromContext(context));
2078 }
2079 const int num_intermediate_tensors = node->intermediates->size;
2080 TfLiteTensor* scratch0;
2081 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 0, &scratch0));
2082 TfLiteTensor* scratch1;
2083 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 1, &scratch1));
2084 TfLiteTensor* scratch2;
2085 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 2, &scratch2));
2086 TfLiteTensor* scratch3;
2087 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 3, &scratch3));
2088 TfLiteTensor* scratch4;
2089 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 4, &scratch4));
2090 TfLiteTensor* scratch5;
2091 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &scratch5));
2092 if (num_intermediate_tensors == 5) {
2093 return lstm_eval::EvalInteger8x8_16(
2094 input, input_to_input_weights, input_to_forget_weights,
2095 input_to_cell_weights, input_to_output_weights,
2096 recurrent_to_input_weights, recurrent_to_forget_weights,
2097 recurrent_to_cell_weights, recurrent_to_output_weights,
2098 cell_to_input_weights, cell_to_forget_weights,
2099 cell_to_output_weights, input_layer_norm_coefficients,
2100 forget_layer_norm_coefficients, cell_layer_norm_coefficients,
2101 output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
2102 cell_gate_bias, output_gate_bias, projection_weights,
2103 projection_bias, params, /*forward_sequence=*/true,
2104 /*time_major=*/true, &op_data->integer_lstm_param, output_state,
2105 cell_state, output, scratch0, scratch1, scratch2, scratch3,
2106 scratch4, scratch5, CpuBackendContext::GetFromContext(context));
2107 }
2108 TfLiteTensor* scratch6;
2109 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 6, &scratch6));
2110 TfLiteTensor* scratch7;
2111 TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 7, &scratch7));
2112 return lstm_eval::EvalInteger8x8_8(
2113 input, input_to_input_weights, input_to_forget_weights,
2114 input_to_cell_weights, input_to_output_weights,
2115 recurrent_to_input_weights, recurrent_to_forget_weights,
2116 recurrent_to_cell_weights, recurrent_to_output_weights,
2117 cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
2118 input_layer_norm_coefficients, forget_layer_norm_coefficients,
2119 cell_layer_norm_coefficients, output_layer_norm_coefficients,
2120 input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias,
2121 projection_weights, projection_bias, params, output_state, cell_state,
2122 output, &op_data->integer_lstm_param, scratch0, scratch1, scratch2,
2123 scratch3, scratch4, scratch5, scratch6, scratch7);
2124 }
2125 default:
2126 TF_LITE_KERNEL_LOG(context, "Type %d is not currently supported.",
2127 input_to_output_weights->type);
2128 return kTfLiteError;
2129 }
2130}
2131// LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
2132
2133} // namespace full
2134
2135// For basic kernel (5-inputs).
2136namespace basic {
2137
2138enum InputTensor {
2139 kInputData = 0,
2140 kInputPrevActivation = 1,
2141 kInputWeights = 2,
2142 kInputBiases = 3,
2143 kInputPrevState = 4,
2144 kInputNum = 5,
2145};
2146
2147enum OutputTensor {
2148 kOutputActivation = 0,
2149 kOutputState = 1,
2150 kOutputConcatTemp = 2,
2151 kOutputActivationTemp = 3,
2152 kOutputNum = 4,
2153};
2154
2155void* Init(TfLiteContext* context, const char* buffer, size_t length) {
2156 auto* op_data = new OpData();
2157 op_data->kernel_type = kTfLiteLSTMBasicKernel;
2158 // `scratch_tensor_index` is unused in this kernel.
2159 op_data->scratch_tensor_index = -1;
2160 return op_data;
2161}
2162
2163TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
2164 TF_LITE_ENSURE(context, node->inputs->size == kInputNum);
2165 TF_LITE_ENSURE(context, node->outputs->size == kOutputNum);
2166
2167 const TfLiteTensor* input;
2168 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
2169 const TfLiteTensor* prev_activation;
2170 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
2171 &prev_activation));
2172 const TfLiteTensor* weights;
2173 TF_LITE_ENSURE_OK(context,
2174 GetInputSafe(context, node, kInputWeights, &weights));
2175 const TfLiteTensor* bias;
2176 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
2177 const TfLiteTensor* prev_state;
2178 TF_LITE_ENSURE_OK(context,
2179 GetInputSafe(context, node, kInputPrevState, &prev_state));
2180
2181 TF_LITE_ENSURE_EQ(context, input->dims->size, 2);
2182 const int num_batches = input->dims->data[0];
2183 const int input_depth = input->dims->data[1];
2184
2185 TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2);
2186 TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches);
2187 const int activation_depth = prev_activation->dims->data[1];
2188 const int total_depth = input_depth + activation_depth;
2189
2190 TF_LITE_ENSURE_EQ(context, weights->dims->size, 2);
2191 TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth);
2192 TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth);
2193
2194 TF_LITE_ENSURE_EQ(context, bias->dims->size, 1);
2195 TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth);
2196
2197 TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2);
2198 TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches);
2199 TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth);
2200
2201 TfLiteTensor* activation_out;
2202 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
2203 &activation_out));
2204 TfLiteTensor* state_out;
2205 TF_LITE_ENSURE_OK(context,
2206 GetOutputSafe(context, node, kOutputState, &state_out));
2207 TfLiteTensor* concat_temp;
2208 TF_LITE_ENSURE_OK(
2209 context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
2210 TfLiteTensor* activation_temp;
2211 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
2212 &activation_temp));
2213
2214 TF_LITE_ENSURE_OK(context, context->ResizeTensor(
2215 context, activation_out,
2216 TfLiteIntArrayCopy(prev_activation->dims)));
2217 TF_LITE_ENSURE_OK(
2218 context, context->ResizeTensor(context, state_out,
2219 TfLiteIntArrayCopy(prev_state->dims)));
2220
2221 TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2);
2222 concat_temp_size->data[0] = num_batches;
2223 concat_temp_size->data[1] = total_depth;
2224 TF_LITE_ENSURE_OK(
2225 context, context->ResizeTensor(context, concat_temp, concat_temp_size));
2226 TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2);
2227 activation_temp_size->data[0] = num_batches;
2228 activation_temp_size->data[1] = 4 * activation_depth;
2229 TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp,
2230 activation_temp_size));
2231
2232 // Set the state tensors as persistent.
2233 for (auto index : {kInputPrevActivation, kInputPrevState}) {
2234 TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
2235 tensor->allocation_type = kTfLiteArenaRwPersistent;
2236 }
2237 return kTfLiteOk;
2238}
2239
2240TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
2241 const TfLiteTensor* input;
2242 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input));
2243 const TfLiteTensor* prev_activation;
2244 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation,
2245 &prev_activation));
2246 const TfLiteTensor* weights;
2247 TF_LITE_ENSURE_OK(context,
2248 GetInputSafe(context, node, kInputWeights, &weights));
2249 const TfLiteTensor* bias;
2250 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias));
2251 const TfLiteTensor* prev_state;
2252 TF_LITE_ENSURE_OK(context,
2253 GetInputSafe(context, node, kInputPrevState, &prev_state));
2254
2255 TfLiteTensor* activation_out;
2256 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation,
2257 &activation_out));
2258 TfLiteTensor* state_out;
2259 TF_LITE_ENSURE_OK(context,
2260 GetOutputSafe(context, node, kOutputState, &state_out));
2261 TfLiteTensor* concat_temp;
2262 TF_LITE_ENSURE_OK(
2263 context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp));
2264 TfLiteTensor* activation_temp;
2265 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp,
2266 &activation_temp));
2267
2268 if (input->type == kTfLiteFloat32 &&
2269 prev_activation->type == kTfLiteFloat32 &&
2270 weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 &&
2271 prev_state->type == kTfLiteFloat32 && state_out->type == kTfLiteFloat32 &&
2272 activation_out->type == kTfLiteFloat32 &&
2273 concat_temp->type == kTfLiteFloat32 &&
2274 activation_temp->type == kTfLiteFloat32) {
2275 tflite::LstmCellParams op_params;
2276 // Float LSTM cell does not need parameters to be set: leave untouched.
2277 optimized_ops::LstmCell(
2278 op_params,
2279 // Inputs.
2280 GetTensorShape(input), GetTensorData<float>(input),
2281 GetTensorShape(prev_activation), GetTensorData<float>(prev_activation),
2282 GetTensorShape(weights), GetTensorData<float>(weights),
2283 GetTensorShape(bias), GetTensorData<float>(bias),
2284 GetTensorShape(prev_state), GetTensorData<float>(prev_state),
2285 // Outputs.
2286 GetTensorShape(state_out), GetTensorData<float>(state_out),
2287 GetTensorShape(activation_out), GetTensorData<float>(activation_out),
2288 GetTensorShape(concat_temp), GetTensorData<float>(concat_temp),
2289 GetTensorShape(activation_temp), GetTensorData<float>(activation_temp),
2290 CpuBackendContext::GetFromContext(context));
2291 } else if (input->type == kTfLiteUInt8 &&
2292 prev_activation->type == kTfLiteUInt8 &&
2293 weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 &&
2294 prev_state->type == kTfLiteInt16 &&
2295 state_out->type == kTfLiteInt16 &&
2296 activation_out->type == kTfLiteUInt8 &&
2297 concat_temp->type == kTfLiteUInt8 &&
2298 activation_temp->type == kTfLiteInt16) {
2299 int state_scale_log2_rounded;
2300 if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) {
2301 TF_LITE_KERNEL_LOG(
2302 context,
2303 "The internal state of a LSTM cell must have a power-of-two scale.");
2304 return kTfLiteError;
2305 }
2306 const int state_integer_bits = 15 + state_scale_log2_rounded;
2307 if (state_integer_bits != 4) {
2308 TF_LITE_KERNEL_LOG(context,
2309 "The only case of quantized LstmCell currently "
2310 "supported is with StateIntegerBits==4");
2311 return kTfLiteError;
2312 }
2313
2314 double real_accum_multiplier = 4096 * bias->params.scale;
2315 int32 accum_multiplier;
2316 int accum_shift;
2317 tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier,
2318 &accum_shift);
2319 tflite::LstmCellParams op_params;
2320 op_params.weights_zero_point = weights->params.zero_point;
2321 op_params.accum_multiplier = accum_multiplier;
2322 op_params.accum_shift = accum_shift;
2323 optimized_ops::LstmCell<4>(
2324 op_params,
2325 // Inputs.
2326 GetTensorShape(input), GetTensorData<uint8_t>(input),
2327 GetTensorShape(prev_activation),
2328 GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights),
2329 GetTensorData<uint8_t>(weights), GetTensorShape(bias),
2330 GetTensorData<int32_t>(bias), GetTensorShape(prev_state),
2331 GetTensorData<int16_t>(prev_state),
2332 // Outputs.
2333 GetTensorShape(state_out), GetTensorData<int16_t>(state_out),
2334 GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out),
2335 GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp),
2336 GetTensorShape(activation_temp),
2337 GetTensorData<int16_t>(activation_temp),
2338 CpuBackendContext::GetFromContext(context));
2339 } else {
2340 TF_LITE_KERNEL_LOG(context,
2341 "Unsupported combination of data types for LstmCell");
2342 return kTfLiteError;
2343 }
2344
2345 memcpy(prev_activation->data.raw, activation_out->data.raw,
2346 activation_out->bytes);
2347 memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes);
2348
2349 return kTfLiteOk;
2350}
2351
2352} // namespace basic
2353
2354void* Init(TfLiteContext* context, const char* buffer, size_t length) {
2355 const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer);
2356 switch (params->kernel_type) {
2357 case kTfLiteLSTMFullKernel:
2358 return full::Init(context, buffer, length);
2359 case kTfLiteLSTMBasicKernel:
2360 return basic::Init(context, buffer, length);
2361 default:
2362 return nullptr;
2363 }
2364}
2365void Free(TfLiteContext* context, void* buffer) {
2366 delete static_cast<OpData*>(buffer);
2367}
2368
2369TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
2370 const auto* op_data = static_cast<const OpData*>(node->user_data);
2371 switch (op_data->kernel_type) {
2372 case kTfLiteLSTMFullKernel:
2373 return full::Prepare(context, node);
2374 case kTfLiteLSTMBasicKernel:
2375 return basic::Prepare(context, node);
2376 default:
2377 return kTfLiteError;
2378 }
2379}
2380
2381TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
2382 const auto* op_data = static_cast<const OpData*>(node->user_data);
2383 switch (op_data->kernel_type) {
2384 case kTfLiteLSTMFullKernel:
2385 return full::Eval(context, node);
2386 case kTfLiteLSTMBasicKernel:
2387 return basic::Eval(context, node);
2388 default:
2389 return kTfLiteError;
2390 }
2391}
2392
2393} // namespace lstm
2394
2395TfLiteRegistration* Register_LSTM() {
2396 static TfLiteRegistration r = {lstm::Init, lstm::Free, lstm::Prepare,
2397 lstm::Eval};
2398 return &r;
2399}
2400
2401} // namespace builtin
2402} // namespace ops
2403} // namespace tflite
2404