/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>

#include <algorithm>
#include <cstddef>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/lstm_shared.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_lstm {
namespace {

struct OpData {
  // Whether the LSTM uses layer normalization.
  bool use_layer_norm;
  // The scratch tensor index.
  int scratch_tensor_index;
  bool compute_row_sums = false;

  lstm_eval::IntegerLstmParameter integer_lstm_param;
};

TfLiteStatus PopulateQuantizedLstmParams8x8_16(
    TfLiteContext* context, TfLiteNode* node,
    lstm_eval::IntegerLstmParameter* integer_lstm_param) {
  // Calculate quantized clip for projection and cell.
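  // The clip values are given in float and converted to the quantized domain
  // by dividing by the corresponding state scale and saturating to the target
  // integer range (int16 for the cell state, int8 for the projection output).
  // Illustrative example (hypothetical numbers): with cell_clip = 8.0 and a
  // cell-state scale of 2^-11, the quantized clip becomes
  // min(max(8.0 / 2^-11, -32768.0f), 32767.0f) = 16384.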
  const auto* params =
      static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data);
  const float cell_clip = params->cell_clip;
  const float proj_clip = params->proj_clip;

  const TfLiteTensor* cell_state =
      GetVariableInput(context, node, lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);
  TfLiteTensor* output_tensor;
  TF_LITE_ENSURE_OK(
      context,
      GetOutputSafe(context, node, lstm::full::kOutputTensor, &output_tensor));

  TF_LITE_ENSURE(context,
                 cell_state->quantization.type != kTfLiteNoQuantization);
  auto* cell_state_params =
      static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params);
  TF_LITE_ENSURE(context,
                 output_tensor->quantization.type != kTfLiteNoQuantization);
  auto* proj_params = static_cast<TfLiteAffineQuantization*>(
      output_tensor->quantization.params);
  if (cell_clip > 0.0) {
    integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min(
        std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f),
        32767.0f));
  } else {
    integer_lstm_param->quantized_cell_clip = 0;
  }
  if (proj_clip > 0.0) {
    integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min(
        std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f));
  } else {
    integer_lstm_param->quantized_proj_clip = 0;
  }

  // Calculate effective scales.
  OpData* op_data = static_cast<OpData*>(node->user_data);
  const bool use_layer_norm = op_data->use_layer_norm;

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kInputLayerNormCoefficientsTensor);
  const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
  const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kCellLayerNormCoefficientsTensor);
  const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor(
      context, node, lstm::full::kOutputLayerNormCoefficientsTensor);

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);

  TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);

  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool use_peephole = (cell_to_output_weights != nullptr);
  const bool use_projection = (projection_weights != nullptr);

  // Get intermediate scales and zero points.
  std::vector<float> intermediate_scale;
  std::vector<int32> intermediate_zp;
  for (int i = 0; i < 4; ++i) {
    if (use_layer_norm) {
      TfLiteTensor* intermediate;
      TF_LITE_ENSURE_OK(context,
                        GetIntermediatesSafe(context, node, i, &intermediate));
      TF_LITE_ENSURE(context,
                     intermediate->quantization.type != kTfLiteNoQuantization);
      auto* params = static_cast<TfLiteAffineQuantization*>(
          intermediate->quantization.params);
      intermediate_scale.push_back(params->scale->data[0]);
      intermediate_zp.push_back(params->zero_point->data[0]);
    } else {
      // Q3.12 for activation functions.
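      // Q3.12 is a fixed-point format with 12 fractional bits, i.e. a scale
      // of 2^-12 and a representable range of roughly [-8, 8). It is used as
      // the default intermediate scale when layer normalization (and thus the
      // per-gate intermediate quantization parameters) is absent.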
      intermediate_scale.push_back(std::pow(2, -12));
      intermediate_zp.push_back(0);
    }
  }
  // In the absence of projection, hidden becomes output and this intermediate
  // is ignored.
  TfLiteTensor* hidden;
  TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden));
  TF_LITE_ENSURE(context, hidden->quantization.type != kTfLiteNoQuantization);
  auto* hidden_params =
      static_cast<TfLiteAffineQuantization*>(hidden->quantization.params);
  intermediate_scale.push_back(hidden_params->scale->data[0]);
  intermediate_zp.push_back(hidden_params->zero_point->data[0]);

  // Scales.
  const float default_scale = 1.0;
  float input_scale = default_scale;
  float input_to_input_weight_scale = default_scale;
  float recurrent_to_input_weight_scale = default_scale;
  float cell_to_input_weight_scale = default_scale;
  float input_to_forget_weight_scale = default_scale;
  float recurrent_to_forget_weight_scale = default_scale;
  float cell_to_forget_weight_scale = default_scale;
  float input_to_cell_weight_scale = default_scale;
  float recurrent_to_cell_weight_scale = default_scale;
  float input_to_output_weight_scale = default_scale;
  float recurrent_to_output_weight_scale = default_scale;
  float cell_to_output_weight_scale = default_scale;
  float projection_weight_scale = default_scale;
  float layer_norm_input_scale = default_scale;
  float layer_norm_forget_scale = default_scale;
  float layer_norm_cell_scale = default_scale;
  float layer_norm_output_scale = default_scale;
  float output_state_scale = default_scale;
  int cell_scale = 1;

  // Effective scales.
  float effective_input_to_input_scale = default_scale;
  float effective_recurrent_to_input_scale = default_scale;
  float effective_cell_to_input_scale = default_scale;
  float effective_input_to_forget_scale = default_scale;
  float effective_recurrent_to_forget_scale = default_scale;
  float effective_cell_to_forget_scale = default_scale;
  float effective_input_to_cell_scale = default_scale;
  float effective_recurrent_to_cell_scale = default_scale;
  float effective_input_to_output_scale = default_scale;
  float effective_recurrent_to_output_scale = default_scale;
  float effective_cell_to_output_scale = default_scale;
  float effective_proj_scale = default_scale;
  float effective_hidden_scale = default_scale;

  // Populate scales.
  if (!use_cifg) {
    input_to_input_weight_scale = input_to_input_weights->params.scale;
    recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale;
  }

  if (use_peephole) {
    if (!use_cifg) {
      cell_to_input_weight_scale = cell_to_input_weights->params.scale;
    }
    cell_to_forget_weight_scale = cell_to_forget_weights->params.scale;
    cell_to_output_weight_scale = cell_to_output_weights->params.scale;
  }

  if (use_layer_norm) {
    if (!use_cifg) {
      layer_norm_input_scale = input_layer_norm_coefficients->params.scale;
    }
    layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale;
    layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale;
    layer_norm_output_scale = output_layer_norm_coefficients->params.scale;
  }

  if (use_projection) {
    projection_weight_scale = projection_weights->params.scale;
  }
  output_state_scale = output_state->params.scale;

  input_to_forget_weight_scale = input_to_forget_weights->params.scale;
  input_to_cell_weight_scale = input_to_cell_weights->params.scale;
  input_to_output_weight_scale = input_to_output_weights->params.scale;
  recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale;
  recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale;
  recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale;

  // Check the cell state (already used above).
  TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale));
  // TF_LITE_ENSURE(context, cell_scale <= -9);
  integer_lstm_param->cell_scale = cell_scale;
  input_scale = input->params.scale;

  // Calculate effective scales.
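  // Each effective scale folds together the weight scale, the scale of the
  // activation it multiplies (the input or the output state), and the scale
  // of the intermediate tensor that receives the accumulated products:
  //   effective_scale = weight_scale * activation_scale / intermediate_scale
  // so that the int32 accumulator can be rescaled directly into the
  // intermediate's quantized domain.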
  if (!use_cifg) {
    effective_input_to_input_scale =
        input_to_input_weight_scale * input_scale / intermediate_scale[0];
    effective_recurrent_to_input_scale = recurrent_to_input_weight_scale *
                                         output_state_scale /
                                         intermediate_scale[0];
  }
  effective_input_to_forget_scale =
      input_to_forget_weight_scale * input_scale / intermediate_scale[1];
  effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[1];

  effective_input_to_cell_scale =
      input_to_cell_weight_scale * input_scale / intermediate_scale[2];
  effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale *
                                      output_state_scale /
                                      intermediate_scale[2];

  effective_input_to_output_scale =
      input_to_output_weight_scale * input_scale / intermediate_scale[3];
  effective_recurrent_to_output_scale = recurrent_to_output_weight_scale *
                                        output_state_scale /
                                        intermediate_scale[3];

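  // The hidden scale rescales the element-wise product of two Q0.15 gate
  // activations (hence the two 2^-15 factors) into the scale of the hidden
  // intermediate tensor.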
  effective_hidden_scale =
      std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15);

  effective_proj_scale =
      projection_weight_scale * intermediate_scale[4] / output_state_scale;

  if (use_peephole) {
    if (!use_cifg) {
      effective_cell_to_input_scale = std::pow(2, cell_scale) *  // NOLINT
                                      cell_to_input_weight_scale /
                                      intermediate_scale[0];
    }
    effective_cell_to_forget_scale = std::pow(2, cell_scale) *  // NOLINT
                                     cell_to_forget_weight_scale /
                                     intermediate_scale[1];
    effective_cell_to_output_scale = std::pow(2, cell_scale) *  // NOLINT
                                     cell_to_output_weight_scale /
                                     intermediate_scale[3];
  }

  // Decompose scales.
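  // QuantizeMultiplier decomposes each floating-point effective scale into a
  // 32-bit fixed-point multiplier and a power-of-two exponent (the *_a / *_b
  // pairs below), which is the representation the integer kernel uses to
  // rescale accumulators at runtime.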
  QuantizeMultiplier(effective_input_to_input_scale,
                     &integer_lstm_param->effective_input_to_input_scale_a,
                     &integer_lstm_param->effective_input_to_input_scale_b);
  QuantizeMultiplier(effective_recurrent_to_input_scale,
                     &integer_lstm_param->effective_recurrent_to_input_scale_a,
                     &integer_lstm_param->effective_recurrent_to_input_scale_b);
  QuantizeMultiplier(effective_cell_to_input_scale,
                     &integer_lstm_param->effective_cell_to_input_scale_a,
                     &integer_lstm_param->effective_cell_to_input_scale_b);
  QuantizeMultiplier(effective_input_to_forget_scale,
                     &integer_lstm_param->effective_input_to_forget_scale_a,
                     &integer_lstm_param->effective_input_to_forget_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_forget_scale,
      &integer_lstm_param->effective_recurrent_to_forget_scale_a,
      &integer_lstm_param->effective_recurrent_to_forget_scale_b);
  QuantizeMultiplier(effective_cell_to_forget_scale,
                     &integer_lstm_param->effective_cell_to_forget_scale_a,
                     &integer_lstm_param->effective_cell_to_forget_scale_b);
  QuantizeMultiplier(effective_input_to_cell_scale,
                     &integer_lstm_param->effective_input_to_cell_scale_a,
                     &integer_lstm_param->effective_input_to_cell_scale_b);
  QuantizeMultiplier(effective_recurrent_to_cell_scale,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_a,
                     &integer_lstm_param->effective_recurrent_to_cell_scale_b);
  QuantizeMultiplier(effective_input_to_output_scale,
                     &integer_lstm_param->effective_input_to_output_scale_a,
                     &integer_lstm_param->effective_input_to_output_scale_b);
  QuantizeMultiplier(
      effective_recurrent_to_output_scale,
      &integer_lstm_param->effective_recurrent_to_output_scale_a,
      &integer_lstm_param->effective_recurrent_to_output_scale_b);
  QuantizeMultiplier(effective_cell_to_output_scale,
                     &integer_lstm_param->effective_cell_to_output_scale_a,
                     &integer_lstm_param->effective_cell_to_output_scale_b);
  QuantizeMultiplier(effective_proj_scale,
                     &integer_lstm_param->effective_proj_scale_a,
                     &integer_lstm_param->effective_proj_scale_b);
  QuantizeMultiplier(effective_hidden_scale,
                     &integer_lstm_param->effective_hidden_scale_a,
                     &integer_lstm_param->effective_hidden_scale_b);
  QuantizeMultiplier(layer_norm_input_scale,
                     &integer_lstm_param->layer_norm_input_scale_a,
                     &integer_lstm_param->layer_norm_input_scale_b);
  QuantizeMultiplier(layer_norm_forget_scale,
                     &integer_lstm_param->layer_norm_forget_scale_a,
                     &integer_lstm_param->layer_norm_forget_scale_b);
  QuantizeMultiplier(layer_norm_cell_scale,
                     &integer_lstm_param->layer_norm_cell_scale_a,
                     &integer_lstm_param->layer_norm_cell_scale_b);
  QuantizeMultiplier(layer_norm_output_scale,
                     &integer_lstm_param->layer_norm_output_scale_a,
                     &integer_lstm_param->layer_norm_output_scale_b);

  integer_lstm_param->hidden_zp = intermediate_zp[4];

  // 10000 is used to make sure the kernel logic does not overflow.
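  // The variance guards below appear to serve as a floor for the variance
  // used inside the integer layer-norm computation; std::max(1, ...) keeps
  // that floor strictly positive even when the layer-norm scale is tiny.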
  if (!use_cifg) {
    integer_lstm_param->input_variance_guard =
        std::max(1, static_cast<int32_t>(10000 * layer_norm_input_scale));
  }
  integer_lstm_param->forget_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_forget_scale));
  integer_lstm_param->cell_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_cell_scale));
  integer_lstm_param->output_variance_guard =
      std::max(1, static_cast<int32_t>(10000 * layer_norm_output_scale));

  return kTfLiteOk;
}

}  // namespace

// Temporary tensors
enum TemporaryTensor {
  kScratchBuffer = 0,
  kInputQuantized = 1,
  kOutputStateQuantized = 2,
  kCellStateQuantized = 3,
  kInputScalingFactors = 4,
  kOutputStateScalingFactors = 5,
  kProductScalingFactors = 6,
  kRecoveredCellWeights = 7,
  kAccumScratch = 8,
  kInputZeroPoints = 9,
  kOutputStateZeroPoints = 10,
  kRowSums = 11,
  kNumTemporaryTensors = 12,
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell,
                                        bool use_layer_norm, bool is_integer) {
  const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);

  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
  }

  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);

  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
  }

  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);

  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or not at all (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToInputWeightsTensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_input_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToForgetWeightsTensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_forget_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToOutputWeightsTensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(
        context, cell_to_output_weights->type,
        is_integer ? kTfLiteInt16 : input_to_forget_weights->type);
  }

  // Making sure the peephole weights are there all or none.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
    }
  }

  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
                            &forget_gate_bias));
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
  if (is_integer) {
    TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
                                 &cell_gate_bias));
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
  if (is_integer) {
    TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
                            &output_gate_bias));
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
  if (is_integer) {
    TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32);
  } else {
    TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
    }
  }

  // Making sure the projection tensors are consistent:
  // 1) If projection weight is not present, then projection bias should not be
  // present.
  // 2) If projection weight is present, then projection bias is optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  if (use_layer_norm) {
    const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor(
        context, node, lstm::full::kInputLayerNormCoefficientsTensor);
    if (use_cifg) {
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr);
    } else {
      TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1);
      TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0],
                        n_cell);
      if (is_integer) {
        TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
                                kTfLiteInt16);
      } else {
        TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type,
                                kTfLiteFloat32);
      }
    }

    const TfLiteTensor* forget_layer_norm_coefficients;
    TF_LITE_ENSURE_OK(
        context, GetInputSafe(context, node,
                              lstm::full::kForgetLayerNormCoefficientsTensor,
                              &forget_layer_norm_coefficients));
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0],
                      n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
                              kTfLiteInt16);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type,
                              kTfLiteFloat32);
    }

    const TfLiteTensor* cell_layer_norm_coefficients;
    TF_LITE_ENSURE_OK(context,
                      GetInputSafe(context, node,
                                   lstm::full::kCellLayerNormCoefficientsTensor,
                                   &cell_layer_norm_coefficients));
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0],
                      n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
                              kTfLiteInt16);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type,
                              kTfLiteFloat32);
    }

    const TfLiteTensor* output_layer_norm_coefficients;
    TF_LITE_ENSURE_OK(
        context, GetInputSafe(context, node,
                              lstm::full::kOutputLayerNormCoefficientsTensor,
                              &output_layer_norm_coefficients));
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0],
                      n_cell);
    if (is_integer) {
      TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
                              kTfLiteInt16);
    } else {
      TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type,
                              kTfLiteFloat32);
    }
  }

  return kTfLiteOk;
}

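// Precomputes zero_point * row_sum(weight) + bias for every row of
// `weight_tensor` and stores the result in `output`. The integer kernel adds
// this term to the accumulator so that an asymmetric (non-zero zero point)
// operand can be handled with a plain symmetric matmul. A null
// `weight_tensor` is a no-op; a null `bias_tensor` is treated as a zero bias.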
TfLiteStatus PrecomputeZeroPointTimesWeightWithBias(
    TfLiteContext* context, int32_t zero_point,
    const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor,
    std::unique_ptr<int32_t[]>* output) {
  if (weight_tensor == nullptr) {
    return kTfLiteOk;
  }

  const RuntimeShape& weight_shape = GetTensorShape(weight_tensor);
  TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2);
  const int row = weight_shape.Dims(0);
  const int col = weight_shape.Dims(1);
  output->reset(new int32_t[row]);
  if (bias_tensor == nullptr) {
    memset(output->get(), 0, row * sizeof(int32_t));
  } else {
    const int32_t* bias = GetTensorData<int32_t>(bias_tensor);
    memcpy(output->get(), bias, row * sizeof(int32_t));
  }
  if (zero_point != 0) {
    const int8_t* weight = GetTensorData<int8_t>(weight_tensor);
    tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col,
                                                 output->get());
  }
  return kTfLiteOk;
}

TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context,
                                                       OpData* op_data,
                                                       TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
  const TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);

  const int32_t input_zero_point = -input->params.zero_point;
  const int32_t output_state_zero_point = -output_state->params.zero_point;

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);

  lstm_eval::IntegerLstmParameter* integer_lstm_params =
      &op_data->integer_lstm_param;

  const TfLiteTensor* intermediate =
      &context->tensors[node->intermediates->data[4]];
  TF_LITE_ENSURE(context,
                 intermediate->quantization.type != kTfLiteNoQuantization);
  const auto* params =
      static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params);
  const int32_t hidden_zp = params->zero_point->data[0];

  // Get bias and perform zero point calculation.
  // When there is layer normalization, the gate bias does not apply to matmul
  // directly:
  //   y = ln(w * x + w * r + w * c) + b.
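  // Consequently, when layer normalization is used the gate biases are passed
  // as nullptr below: they are applied after normalization inside the kernel
  // and must not be folded into these precomputed zero-point * weight terms.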
  const bool is_layer_norm = op_data->use_layer_norm;

  // Forget gate.
  const TfLiteTensor* forget_gate_bias =
      is_layer_norm
          ? nullptr
          : GetInput(context, node, lstm::full::kForgetGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_forget_weights, forget_gate_bias,
          &(integer_lstm_params->input_to_forget_effective_bias)));

  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_forget_weights,
          nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias)));

  // Modulation gate.
  const TfLiteTensor* cell_gate_bias =
      is_layer_norm ? nullptr
                    : GetInput(context, node, lstm::full::kCellGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_cell_weights, cell_gate_bias,
          &(integer_lstm_params->input_to_cell_effective_bias)));
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_cell_weights, nullptr,
          &(integer_lstm_params->recurrent_to_cell_effective_bias)));

  // Output gate.
  const TfLiteTensor* output_gate_bias =
      is_layer_norm
          ? nullptr
          : GetInput(context, node, lstm::full::kOutputGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_output_weights, output_gate_bias,
          &(integer_lstm_params->input_to_output_effective_bias)));

  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_output_weights,
          nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias)));

  // Input gate. The calculation is only meaningful for the non-CIFG case.
  const TfLiteTensor* input_gate_bias =
      is_layer_norm ? nullptr
                    : GetInput(context, node, lstm::full::kInputGateBiasTensor);
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, input_zero_point, input_to_input_weights, input_gate_bias,
          &(integer_lstm_params->input_to_input_effective_bias)));
  TF_LITE_ENSURE_OK(
      context,
      PrecomputeZeroPointTimesWeightWithBias(
          context, output_state_zero_point, recurrent_to_input_weights, nullptr,
          &(integer_lstm_params->recurrent_to_input_effective_bias)));

  // Projection bias. The calculation is only meaningful when projection is
  // used.
  TF_LITE_ENSURE_OK(context,
                    PrecomputeZeroPointTimesWeightWithBias(
                        context, hidden_zp, projection_weights, projection_bias,
                        &(integer_lstm_params->projection_effective_bias)));
  return kTfLiteOk;
}

// Resize the output and state tensors based on the sizes of the input tensors.
// Allocate a temporary scratch tensor. Also check that the sizes of the input
// tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const int scratch_tensor_index = op_data->scratch_tensor_index;

  // Check we have all the inputs and outputs we need.
  bool use_layer_norm = false;
  if (node->inputs->size == 24) {
    const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor(
        context, node, lstm::full::kForgetLayerNormCoefficientsTensor);
    if (forget_layer_norm_coefficients == nullptr) {
      use_layer_norm = false;
    } else {
      use_layer_norm = true;
    }
  } else if (node->inputs->size == 20) {
    // This is deprecated and is only kept here for backward compatibility.
    use_layer_norm = false;
  } else {
    TF_LITE_KERNEL_LOG(
        context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs",
        node->inputs->size);
    return kTfLiteError;
  }
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  op_data->use_layer_norm = use_layer_norm;

  // Infer the batch size, sequence length, number of cells and number of
  // outputs from the input tensors.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));
  const bool is_integer = input->type == kTfLiteInt8;
  TF_LITE_ENSURE(context, input->dims->size > 1);
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const bool time_major = params->time_major;
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));
  const int n_cell = input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);

  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
                    n_cell);
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_output,
                                          n_cell, use_layer_norm, is_integer));

  // Get the pointer to output, output_state and cell_state buffer tensors.
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
                                           lstm::full::kOutputTensor, &output));

  TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TF_LITE_ENSURE(context, output_state != nullptr);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, lstm::full::kCellStateTensor);
  TF_LITE_ENSURE(context, cell_state != nullptr);

  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D. It's fine as long as the total size is
  // correct.
  TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output);
  TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);

  // Resize the output tensors.
  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
  output_size->data[input->dims->size - 1] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  if (is_integer) {
    const int num_intermediate_tensors = node->intermediates->size;
    TF_LITE_ENSURE(context, num_intermediate_tensors == 5);
  }

  TfLiteIntArrayFree(node->temporaries);
  if (IsHybridOp(input, input_to_output_weights)) {
    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
  } else if (is_integer) {
    node->temporaries = TfLiteIntArrayCreate(6);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
  node->temporaries->data[kScratchBuffer] =
      scratch_tensor_index + kScratchBuffer;

  // Create a scratch buffer tensor.
  TfLiteTensor* scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
                                              &scratch_buffer));
  scratch_buffer->type = input->type;
  scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const bool use_cifg = (input_to_input_weights == nullptr);
  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
  scratch_buffer_size->data[0] = n_batch;
  if (use_cifg) {
    // Reserving space for Cell, Forget, Output gates and scratch accumulation
    // buffer and an extra 16 bytes to avoid internal ruy copies.
    scratch_buffer_size->data[1] = n_cell * 4 + 16;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates and scratch
    // accumulation buffer and an extra 16 bytes to avoid internal ruy copies.
    scratch_buffer_size->data[1] = n_cell * 5 + 16;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                   scratch_buffer_size));

  if (IsHybridOp(input, input_to_output_weights)) {
    op_data->compute_row_sums = true;
    // Allocate temporary tensors to store quantized values of input,
    // output_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[kOutputStateQuantized] =
        scratch_tensor_index + kOutputStateQuantized;
    TfLiteTensor* output_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kOutputStateQuantized,
                                       &output_state_quantized));
    output_state_quantized->type = input_to_output_weights->type;
    output_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(output_state_quantized->dims,
                             output_state->dims)) {
      TfLiteIntArray* output_state_quantized_size =
          TfLiteIntArrayCopy(output_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, output_state_quantized,
                                              output_state_quantized_size));
    }
    node->temporaries->data[kCellStateQuantized] =
        scratch_tensor_index + kCellStateQuantized;
    TfLiteTensor* cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kCellStateQuantized,
                                       &cell_state_quantized));
    cell_state_quantized->type = input_to_output_weights->type;
    cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
      TfLiteIntArray* cell_state_quantized_size =
          TfLiteIntArrayCopy(cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, cell_state_quantized,
                                              cell_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors and product scaling
    // factors. The latter is convenience storage that makes it possible to
    // quantize a vector once (which produces the scaling factors) and multiply
    // it with different matrices (which requires multiplying the scaling
    // factors with the scaling factor of the matrix).
    node->temporaries->data[kInputScalingFactors] =
        op_data->scratch_tensor_index + kInputScalingFactors;
    TfLiteTensor* input_sf;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
    input_sf->type = kTfLiteFloat32;
    input_sf->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
      input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_sf, input_sf_size));
    }
    node->temporaries->data[kOutputStateScalingFactors] =
        op_data->scratch_tensor_index + kOutputStateScalingFactors;
    TfLiteTensor* output_state_sf;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
                                  &output_state_sf));
    output_state_sf->type = kTfLiteFloat32;
    output_state_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
      output_state_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
                                                       output_state_sf_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kProductScalingFactors,
                                       &prod_scaling_factors));
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // this is used for diagonal matrices, only need to store n_cell values.
    node->temporaries->data[kRecoveredCellWeights] =
        scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kRecoveredCellWeights,
                                       &recovered_cell_weights));
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }

    // Allocate a temporary tensor to store the accumulated int32 values.
    node->temporaries->data[kAccumScratch] =
        scratch_tensor_index + kAccumScratch;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {n_cell, n_batch};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = n_cell;
      accum_size->data[1] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }
    node->temporaries->data[kInputZeroPoints] =
        op_data->scratch_tensor_index + kInputZeroPoints;
    TfLiteTensor* input_zp;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
    input_zp->type = kTfLiteFloat32;
    input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
      input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_zp, input_zp_size));
    }
    node->temporaries->data[kOutputStateZeroPoints] =
        op_data->scratch_tensor_index + kOutputStateZeroPoints;
    TfLiteTensor* output_state_zp;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kOutputStateZeroPoints,
                                       &output_state_zp));
    output_state_zp->type = kTfLiteFloat32;
    output_state_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
      output_state_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
                                                       output_state_zp_size));
    }
    node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kRowSums, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->name = "Lstm_row_sums";
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_rows = use_cifg ? 6 : 8;
    const TfLiteTensor* projection_weights = GetOptionalInputTensor(
        context, node, lstm::full::kProjectionWeightsTensor);
    if (projection_weights != nullptr) {
      row_sums_rows += ceil(static_cast<float>(n_output) / n_cell);
    }
    int row_sums_dims[2] = {row_sums_rows, n_cell};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
      row_sums_size->data[0] = row_sums_dims[0];
      row_sums_size->data[1] = row_sums_dims[1];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }

  if (is_integer) {
    // Integer UnidirectionalSequenceLSTM prepare function for 8x8->16.
    // This code path needs 5 intermediate tensors per Op.
    // Populate quantization parameters.
    PopulateQuantizedLstmParams8x8_16(context, node,
                                      &op_data->integer_lstm_param);
    // Allocate scratch buffers. We need six 16-bit buffers with size
    // n_batch * n_cell, one 8-bit buffer with size n_batch * n_cell, and one
    // 32-bit buffer with size n_batch * n_cell.
    //
    // Handle the CIFG case as well, which might save one buffer.
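    // The loop below creates six scratch tensors: indices 0-3 are int16
    // (presumably the per-gate scratch areas), index 4 is int8 (presumably
    // the quantized hidden/output buffer), and index 5 is int32 (presumably
    // the accumulator), matching what lstm_eval::EvalInteger8x8_16 is handed
    // in Eval below.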
    for (int scratch_index = 0; scratch_index < 6; ++scratch_index) {
      node->temporaries->data[scratch_index] =
          op_data->scratch_tensor_index + scratch_index;
      TfLiteTensor* scratch_tensor;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, scratch_index,
                                                  &scratch_tensor));

      scratch_tensor->type = kTfLiteInt16;
      if (scratch_index == 4) {
        scratch_tensor->type = kTfLiteInt8;
      } else if (scratch_index == 5) {
        scratch_tensor->type = kTfLiteInt32;
      }

      scratch_tensor->allocation_type = kTfLiteArenaRw;
      const int scratch_dimension[2] = {n_batch, n_cell};
      if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2,
                                     scratch_dimension)) {
        TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
        scratch_buffer_size->data[0] = n_batch;
        scratch_buffer_size->data[1] = n_cell;
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, scratch_tensor,
                                                scratch_buffer_size));
      }
    }

    // Populate precomputed zp * weight.
    TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias(
                                   context, op_data, node));
  }

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
          node->builtin_data);
  const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const bool use_layer_norm = op_data->use_layer_norm;
  const bool time_major = params->time_major;
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kInputTensor, &input));

  const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kInputToInputWeightsTensor);
  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor,
                   &input_to_forget_weights));
  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node,
                                          lstm::full::kInputToCellWeightsTensor,
                                          &input_to_cell_weights));
  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor,
                   &input_to_output_weights));

  const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kRecurrentToInputWeightsTensor);
  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor,
                   &recurrent_to_forget_weights));
  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor,
                   &recurrent_to_cell_weights));
  const TfLiteTensor* recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor,
                   &recurrent_to_output_weights));

  const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToInputWeightsTensor);
  const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToForgetWeightsTensor);
  const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor(
      context, node, lstm::full::kCellToOutputWeightsTensor);

  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor);
  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor,
                            &forget_gate_bias));
  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, lstm::full::kCellGateBiasTensor,
                                 &cell_gate_bias));
  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor,
                            &output_gate_bias));

  const TfLiteTensor* projection_weights = GetOptionalInputTensor(
      context, node, lstm::full::kProjectionWeightsTensor);
  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor);

  TfLiteTensor* output_state =
      GetVariableInput(context, node, lstm::full::kOutputStateTensor);
  TFLITE_DCHECK(output_state != nullptr);
  TfLiteTensor* cell_state =
      GetVariableInput(context, node, lstm::full::kCellStateTensor);
  TFLITE_DCHECK(cell_state != nullptr);

  const TfLiteTensor* input_layer_norm_coefficients =
      use_layer_norm
          ? GetOptionalInputTensor(
                context, node, lstm::full::kInputLayerNormCoefficientsTensor)
          : nullptr;
  const TfLiteTensor* forget_layer_norm_coefficients =
      use_layer_norm ? GetInput(context, node,
                                lstm::full::kForgetLayerNormCoefficientsTensor)
                     : nullptr;
  const TfLiteTensor* cell_layer_norm_coefficients =
      use_layer_norm ? GetInput(context, node,
                                lstm::full::kCellLayerNormCoefficientsTensor)
                     : nullptr;
  const TfLiteTensor* output_layer_norm_coefficients =
      use_layer_norm ? GetInput(context, node,
                                lstm::full::kOutputLayerNormCoefficientsTensor)
                     : nullptr;

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node,
                                           lstm::full::kOutputTensor, &output));

  // Copy out the LSTM specific params so they can be passed to the function.
  TfLiteLSTMParams lstm_params;
  lstm_params.activation = params->activation;
  lstm_params.cell_clip = params->cell_clip;
  lstm_params.proj_clip = params->proj_clip;
  lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs;

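  // Dispatch on the weight type: float32 weights use the float kernel, 8-bit
  // weights with a float input use the hybrid kernel, and 8-bit weights with
  // an int8 input use the fully integer 8x8->16 kernel.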
  switch (input_to_output_weights->type) {
    case kTfLiteFloat32: {
      // Index the scratch buffers pointers to the global scratch buffer.
      TfLiteTensor* scratch_buffer;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer,
                                                  &scratch_buffer));
      return lstm_eval::EvalFloat(
          input, input_to_input_weights, input_to_forget_weights,
          input_to_cell_weights, input_to_output_weights,
          recurrent_to_input_weights, recurrent_to_forget_weights,
          recurrent_to_cell_weights, recurrent_to_output_weights,
          cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
          input_layer_norm_coefficients, forget_layer_norm_coefficients,
          cell_layer_norm_coefficients, output_layer_norm_coefficients,
          /*aux_input=*/nullptr,
          /*aux_input_to_input_weights=*/nullptr,
          /*aux_input_to_forget_weights=*/nullptr,
          /*aux_input_to_cell_weights=*/nullptr,
          /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
          forget_gate_bias, cell_gate_bias, output_gate_bias,
          projection_weights, projection_bias, &lstm_params,
          /*forward_sequence=*/true, time_major,
          /*output_offset=*/0, scratch_buffer, output_state, cell_state, output,
          CpuBackendContext::GetFromContext(context));
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      const bool is_hybrid = input->type == kTfLiteFloat32;
      if (is_hybrid) {
        // Index the scratch buffers pointers to the global scratch buffer.
        TfLiteTensor* scratch_buffer;
        TF_LITE_ENSURE_OK(
            context,
            GetTemporarySafe(context, node, kScratchBuffer, &scratch_buffer));

        OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
        TfLiteTensor* row_sums;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, kRowSums, &row_sums));
        const int row_sums_size = row_sums->dims->data[0];
        return lstm_eval::EvalHybrid(
            input, input_to_input_weights,
            /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights,
            /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights,
            /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights,
            /*input_to_output_weights_ledger*/ nullptr,
            recurrent_to_input_weights,
            /*recurrent_to_input_weights_ledger*/ nullptr,
            recurrent_to_forget_weights,
            /*recurrent_to_forget_weights_ledger*/ nullptr,
            recurrent_to_cell_weights,
            /*recurrent_to_cell_weights_ledger*/ nullptr,
            recurrent_to_output_weights,
            /*recurrent_to_output_weights_ledger*/ nullptr,
            cell_to_input_weights, cell_to_forget_weights,
            cell_to_output_weights, input_layer_norm_coefficients,
            forget_layer_norm_coefficients, cell_layer_norm_coefficients,
            output_layer_norm_coefficients,
            /*aux_input=*/nullptr,
            /*aux_input_to_input_weights=*/nullptr,
            /*aux_input_to_forget_weights=*/nullptr,
            /*aux_input_to_cell_weights=*/nullptr,
            /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
            forget_gate_bias, cell_gate_bias, output_gate_bias,
            projection_weights, /*projection_weights_ledger*/ nullptr,
            projection_bias, &lstm_params,
            /*forward_sequence=*/true, time_major,
            /*output_offset=*/0, scratch_buffer,
            GetTemporary(context, node, kInputScalingFactors),
            /*aux_input_sf=*/nullptr,
            GetTemporary(context, node, kOutputStateScalingFactors),
            GetTemporary(context, node, kProductScalingFactors),
            GetTemporary(context, node, kRecoveredCellWeights),
            GetTemporary(context, node, kInputQuantized),
            /*aux_input_quantized=*/nullptr,
            GetTemporary(context, node, kOutputStateQuantized),
            GetTemporary(context, node, kCellStateQuantized), output_state,
            cell_state, GetTemporary(context, node, kAccumScratch), output,
            GetTemporary(context, node, kInputZeroPoints),
            /*aux_input_zp=*/nullptr,
            GetTemporary(context, node, kOutputStateZeroPoints), row_sums,
            row_sums_size, &op_data->compute_row_sums,
            CpuBackendContext::GetFromContext(context));
      } else {
        TfLiteTensor* scratch0;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 0, &scratch0));
        TfLiteTensor* scratch1;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 1, &scratch1));
        TfLiteTensor* scratch2;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 2, &scratch2));
        TfLiteTensor* scratch3;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 3, &scratch3));
        TfLiteTensor* scratch4;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 4, &scratch4));
        TfLiteTensor* scratch5;
        TF_LITE_ENSURE_OK(context,
                          GetTemporarySafe(context, node, 5, &scratch5));
        return lstm_eval::EvalInteger8x8_16(
            input, input_to_input_weights, input_to_forget_weights,
            input_to_cell_weights, input_to_output_weights,
            recurrent_to_input_weights, recurrent_to_forget_weights,
            recurrent_to_cell_weights, recurrent_to_output_weights,
            cell_to_input_weights, cell_to_forget_weights,
            cell_to_output_weights, input_layer_norm_coefficients,
            forget_layer_norm_coefficients, cell_layer_norm_coefficients,
            output_layer_norm_coefficients, input_gate_bias, forget_gate_bias,
            cell_gate_bias, output_gate_bias, projection_weights,
            projection_bias, &lstm_params, /*forward_sequence=*/true,
            time_major, &op_data->integer_lstm_param, output_state, cell_state,
            output, scratch0, scratch1, scratch2, scratch3, scratch4, scratch5,
            CpuBackendContext::GetFromContext(context));
      }
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
                         TfLiteTypeGetName(input_to_output_weights->type));
      return kTfLiteError;
  }
}
}  // namespace unidirectional_sequence_lstm

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {unidirectional_sequence_lstm::Init,
                                 unidirectional_sequence_lstm::Free,
                                 unidirectional_sequence_lstm::Prepare,
                                 unidirectional_sequence_lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite