/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/lstm_eval.h"

#include <math.h>
#include <string.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <vector>

#include "ruy/matrix.h"  // from @ruy
#include "ruy/mul_params.h"  // from @ruy
#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "ruy/ruy.h"  // from @ruy
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace lstm_eval {
namespace {

void MatrixBatchVectorMultiplyAccumulate(
    const float* matrix, const float* vector, const float* result,
    float* output, int m_rows, int m_cols, int n_batch,
    CpuBackendContext* cpu_backend_context) {
  tflite::FullyConnectedParams float_fc_params;
  float_fc_params.float_activation_min = std::numeric_limits<float>::lowest();
  float_fc_params.float_activation_max = std::numeric_limits<float>::max();
  float_fc_params.lhs_cacheable = true;
  float_fc_params.rhs_cacheable = false;

  tflite::RuntimeShape weight_shape({m_rows, m_cols});
  tflite::RuntimeShape input_shape({n_batch, m_cols});
  tflite::RuntimeShape output_shape({n_batch, m_rows});
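  // Note: optimized_ops::FullyConnected treats its bias argument as an
  // m_rows-sized vector that is broadcast across the batch. When n_batch == 1
  // the accumulator 'result' has exactly that shape, so it can be passed in
  // directly as the bias; for n_batch > 1 the bias slot is left null and
  // 'result' is added to the batched output explicitly below.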
  if (n_batch == 1) {
    tflite::optimized_ops::FullyConnected(
        float_fc_params, input_shape, vector, weight_shape, matrix,
        output_shape, result, output_shape, output, cpu_backend_context);
  } else {
    tflite::optimized_ops::FullyConnected(
        float_fc_params, input_shape, vector, weight_shape, matrix,
        output_shape, nullptr, output_shape, output, cpu_backend_context);
    for (int i = 0; i < m_rows * n_batch; ++i) {
      output[i] += result[i];
    }
  }
}

void ComputeRowSums(
    int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums,
    int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums,
    int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums,
    int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums,
    int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums,
    int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums,
    int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell,
    int n_input, int n_aux_input, int n_output,
    const int8_t* input_to_input_weights_ptr,
    const int8_t* input_to_forget_weights_ptr,
    const int8_t* input_to_cell_weights_ptr,
    const int8_t* input_to_output_weights_ptr,
    const int8_t* aux_input_to_input_weights_ptr,
    const int8_t* aux_input_to_forget_weights_ptr,
    const int8_t* aux_input_to_cell_weights_ptr,
    const int8_t* aux_input_to_output_weights_ptr,
    const int8_t* recurrent_to_input_weights_ptr,
    const int8_t* recurrent_to_forget_weights_ptr,
    const int8_t* recurrent_to_cell_weights_ptr,
    const int8_t* recurrent_to_output_weights_ptr,
    const int8_t* projection_weights_ptr, bool use_cifg,
    const float* aux_input_ptr) {
  // Compute the row sums for dequantization.
  if (!use_cifg) {
    tensor_utils::ReductionSumVector(input_to_input_weights_ptr,
                                     input_to_input_row_sums, n_cell, n_input);
  }
  tensor_utils::ReductionSumVector(input_to_forget_weights_ptr,
                                   input_to_forget_row_sums, n_cell, n_input);
  tensor_utils::ReductionSumVector(input_to_cell_weights_ptr,
                                   input_to_cell_row_sums, n_cell, n_input);
  tensor_utils::ReductionSumVector(input_to_output_weights_ptr,
                                   input_to_output_row_sums, n_cell, n_input);

  if (aux_input_ptr) {
    if (!use_cifg) {
      tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr,
                                       aux_input_to_input_row_sums, n_cell,
                                       n_aux_input);
    }
    tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr,
                                     aux_input_to_forget_row_sums, n_cell,
                                     n_aux_input);
    tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr,
                                     aux_input_to_cell_row_sums, n_cell,
                                     n_aux_input);
    tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr,
                                     aux_input_to_output_row_sums, n_cell,
                                     n_aux_input);
  }
  if (!use_cifg) {
    tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr,
                                     recurrent_to_input_row_sums, n_cell,
                                     n_output);
  }
  tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr,
                                   recurrent_to_forget_row_sums, n_cell,
                                   n_output);
  tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr,
                                   recurrent_to_cell_row_sums, n_cell,
                                   n_output);
  tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr,
                                   recurrent_to_output_row_sums, n_cell,
                                   n_output);

  if (projection_weights_ptr != nullptr) {
    tensor_utils::ReductionSumVector(
        projection_weights_ptr, projection_weights_row_sums, n_output, n_cell);
  }
}
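
// Note on why the row sums above exist: in the hybrid (float activations,
// int8 weights) path the activations may be asymmetrically quantized, i.e.
// x_float ~ input_scale * (x_quant - input_zero_point). Assuming symmetric
// int8 weights (which is how the hybrid kernels below use them), the dot
// product with one weight row then expands to
//
//   weight_scale * input_scale *
//       (sum_i w_quant[i] * x_quant[i] - input_zero_point * sum_i w_quant[i])
//
// so each weight row's sum can be precomputed once here and reused for every
// batch and time step instead of being recomputed inside the matmul.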

inline float GetTensorScale(const TfLiteTensor* tensor) {
  return tensor == nullptr ? 1.0f : tensor->params.scale;
}

// LINT.IfChange
// Calculates a single LSTM gate.
//
// Implements the following formula: (* is matrix multiply)
//   gate = activate(W_input    * input      +
//                   W_aux      * aux_input  +
//                   W_peephole * cell       +
//                   W_recurrent * prev_output + bias)
// with layer norm:
//   gate = activate(W_norm * normalize(...) + bias)  // not adding bias inside
//
// Activation is sigmoid except for the "cell" gate (configurable, usually
// tanh).
//
// Parameters:
//  Input vectors (to LSTM):    | Size:                | Optional?
//    input                     | n_input              |
//    aux_input                 | n_aux_input          | y (bidir LSTM)
//  Input vectors (persistent states):
//    output_state              | n_output             |
//    cell_state                | n_cell               |
//  'Constant' inputs:
//    input_to_gate_weights     | n_cell * n_input     |
//    aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM)
//    recurrent_to_gate_weights | n_cell * n_output    |
//    cell_to_gate_weights      | n_cell               | y (peephole)
//    gate_bias                 | n_cell               |
//    layer_norm_coefficients   | n_cell               | y (layer norm)
//  Output vector:
//    gate                      | n_cell               |
//  Scalar parameters:
//    n_batch - batch size / number of vectors
//    n_input, n_aux_input, n_output, n_cell - size of vectors.
//    activation - activation to use.
//    is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all
//        zero.
//    use_layer_norm - if doing layer norm LSTM.
inline void CalculateLstmGateFloat(
    const float* input, const float* input_to_gate_weights,
    const float* aux_input, const float* aux_input_to_gate_weights,
    const float* output_state, const float* recurrent_to_gate_weights,
    const float* cell_state, const float* cell_to_gate_weights,
    const float* layer_norm_coefficients, const float* gate_bias,
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation, float* gate,
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
    float* output, CpuBackendContext* context) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  } else {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  float* accumulation_buffer = gate;
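  // The matmuls below ping-pong between the 'gate' scratch and the caller's
  // 'output' scratch: the input and aux-input matmuls each read the running
  // sum from 'accumulation_buffer', write into 'output', and then swap the two
  // pointers. The final recurrent matmul leaves its result in whatever buffer
  // 'output' currently names, which is where the peephole and layer-norm terms
  // accumulate before ApplyActivationToVector writes the activated gate into
  // 'gate'.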
  if (!is_input_all_zeros) {
    MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, input,
                                        accumulation_buffer, output, n_cell,
                                        n_input, n_batch, context);
    std::swap(accumulation_buffer, output);
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, aux_input,
                                        accumulation_buffer, output, n_cell,
                                        n_aux_input, n_batch, context);
    std::swap(accumulation_buffer, output);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, output_state,
                                      accumulation_buffer, output, n_cell,
                                      n_output, n_batch, context);
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_cell, cell_state, n_batch, output);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    tensor_utils::MeanStddevNormalization(output, output, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
                                                output, n_batch, output);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, output);
  }
  // Apply activation
  tensor_utils::ApplyActivationToVector(output, n_batch * n_cell, activation,
                                        gate);
}

// Updates the LSTM cell state, used by both float and hybrid LSTM versions.
//
// Implements the following formula:
//   cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate)
//
// With CIFG LSTM, input gate is replaced by (1-forget_gate).
//
// Parameters:
//  - n_batch, n_cell: sizes of vectors
//  - cell_state: input/output vector, size n_batch*n_cell
//  - input_gate: input vector, size n_batch*n_cell.
//  - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG
//  - cell_gate: input vector, size n_batch*n_cell.
//  - use_cifg: use 1-forget_gate instead of input_gate.
//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state,
                         const float* input_gate, float* forget_gate,
                         const float* cell_gate, bool use_cifg, float clip) {
  tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state,
                                         n_batch * n_cell, cell_state);

  if (use_cifg) {
    // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as
    // scratch, as input_gate array is not allocated in this case. (Be careful
    // not to write to the scratch before reading the forget gate data.)
    float* scratch = forget_gate;
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, scratch, n_batch * n_cell, cell_state);
  } else {
    tensor_utils::VectorVectorCwiseProductAccumulate(
        cell_gate, input_gate, n_batch * n_cell, cell_state);
  }
  if (clip > 0.0f) {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}
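
// Written out for the CIFG case (a sketch of the same update, not a separate
// code path), the two VectorVectorCwiseProduct* calls above compute
//
//   cell_state = clip(forget_gate .* cell_state +
//                     (1 - forget_gate) .* cell_gate, clip)
//
// where .* denotes the elementwise product.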

// Calculates the output state tensor of an LSTM step.
//
// Implements the following formula:
//   output_no_projection = output_gate .* activate(cell_state)
//     (elementwise vector product)
// If no projection is used:
//   output = output_state = output_no_projection
// With projection:
//   output = output_state = clip(W*output_no_projection + bias)
//
// The caller's output buffer may use a leading dimension ('stride') different
// from n_output, so the caller copies output_state into the output row by row
// when needed.
//
// Parameters:
//  - n_batch: batches: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, projection_weights_scale, projection_bias:
//      constant inputs, describing projection matrix and bias.
//  - proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch: scratch area to store output_no_projection. Size n_batch*n_cell.
//  - projection_bias_scratch: scratch area to store projection_bias. Size
//      n_batch*n_cell.
//  - context: the CpuBackendContext for use with matrix multiplications.
void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output,
                              const float* cell_state, const float* output_gate,
                              TfLiteFusedActivation activation,
                              const float* projection_weights,
                              const float* projection_bias,
                              const float proj_clip, float* output_state,
                              float* scratch, float* projection_bias_scratch,
                              CpuBackendContext* context) {
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                        activation, scratch);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell,
                                         scratch);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
                                            projection_bias_scratch);
    } else {
      std::fill_n(projection_bias_scratch, n_batch * n_output, 0.0f);
    }
    MatrixBatchVectorMultiplyAccumulate(projection_weights, scratch,
                                        projection_bias_scratch, output_state,
                                        n_output, n_cell, n_batch, context);
    if (proj_clip > 0.0f) {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  } else {
    std::copy_n(scratch, n_batch * n_output, output_state);
  }
}
// LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\
// ../experimental/kernels/fp16/lstm_eval.cc)

// Calculates a single LSTM gate, hybrid version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateHybrid(
    // Input and weights
    const int8_t* input, const float* input_sf, const int32_t* input_zp,
    const int8_t* input_to_gate_weights,
    const uint8_t* input_to_gate_weights_ledger,
    const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums,
    // Aux input and weights
    const int8_t* aux_input, const float* aux_input_sf,
    const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights,
    const float aux_input_to_gate_weights_scale,
    int32_t* aux_input_to_gate_row_sums,
    // Output state and weights
    const int8_t* output_state, const float* output_state_sf,
    const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights,
    const uint8_t* recurrent_to_gate_weights_ledger,
    const float recurrent_to_gate_weights_scale,
    int32_t* recurrent_to_gate_row_sums,
    // Cell state and weights (peephole LSTM)
    const float* cell_state, const int8_t* cell_to_gate_weights,
    const float cell_to_gate_weights_scale,
    // Layer normalization coefficients (layer norm LSTM) + gate bias
    const float* layer_norm_coefficients, const float* gate_bias,
    // Array sizes
    const int n_batch, const int n_input, const int n_aux_input,
    const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    float* gate,
    // Parameters for performance optimizations
    const bool is_input_all_zeros, const bool is_aux_input_all_zeros,
    const bool is_output_state_all_zeros, bool* compute_row_sums,
    CpuBackendContext* context,
    // Scratch arrays
    float* scratch0,        // size: n_batch
    float* scratch1,        // size: n_cell, only used if peephole LSTM
    int32_t* accum_scratch  // For MatrixBatchVectorMultiplyAccumulate
) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);

  // Initialize scratch buffers with bias for regular lstm or initialize with
  // zero for layer norm lstm.
  if (use_layer_norm) {
    std::fill_n(gate, n_cell * n_batch, 0.0f);
  } else {
    tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate);
  }
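  // Note on the two matmul paths below: a non-null weights ledger means the
  // weight matrix is stored in TF Lite's sparse layout and is multiplied with
  // the Sparse* kernel, which takes one precomputed effective scale per batch
  // (weight_scale * per-batch scaling factor). The dense path instead passes
  // the scales, zero points and row sums separately so the kernel can apply
  // the asymmetric-quantization correction itself.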
  // For each batch and cell: compute input_weight * input.
  // Skip if input is all zeros.
  if (!is_input_all_zeros) {
    if (input_to_gate_weights_ledger != nullptr) {
      std::vector<float> scales(n_batch);
      for (int i = 0; i < n_batch; i++) {
        scales[i] = input_to_gate_weights_scale * input_sf[i];
      }
      tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
          input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input,
          input, scales.data(), n_batch, gate);

    } else {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          input_to_gate_weights, n_cell, n_input, input,
          input_to_gate_weights_scale, input_sf, n_batch, gate,
          /*per_channel_scale=*/nullptr, input_zp, accum_scratch,
          input_to_gate_row_sums, compute_row_sums, scratch0, context);
    }
  }
  // For each batch and cell: compute aux_input_weight * aux_input.
  // Skip if auxiliary input is not available or all zeros.
  if (!is_aux_input_all_zeros) {
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        aux_input_to_gate_weights, n_cell, n_aux_input, aux_input,
        aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate,
        /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch,
        aux_input_to_gate_row_sums, compute_row_sums, scratch0, context);
  }
  // For each batch and cell: compute recurrent_weight * output_state.
  // Skip if output state is all zeros.
  if (!is_output_state_all_zeros) {
    if (recurrent_to_gate_weights_ledger != nullptr) {
      std::vector<float> scales(n_batch);
      for (int i = 0; i < n_batch; i++) {
        scales[i] = recurrent_to_gate_weights_scale * output_state_sf[i];
      }
      tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
          recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell,
          n_output, output_state, scales.data(), n_batch, gate);
    } else {
      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
          recurrent_to_gate_weights, n_cell, n_output, output_state,
          recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate,
          /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch,
          recurrent_to_gate_row_sums, compute_row_sums, scratch0, context);
    }
  }
  // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM)
  if (use_peephole) {
    float* recovered_cell_weights = scratch1;
    tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell,
                                       cell_to_gate_weights_scale,
                                       recovered_cell_weights);
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        recovered_cell_weights, n_cell, cell_state, n_batch, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch);
    tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell,
                                                gate, n_batch, gate);
    tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate);
  }
  // Apply activation
  tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch, activation,
                                        gate);
}

// Calculates the output state tensor of an LSTM step. See Float version too.
//
// Parameters:
//  - n_batch: batches: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, projection_weights_scale, projection_bias:
//      constant inputs, describing projection matrix and bias.
//  - proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - asymmetric_quantize_inputs: parameter to control quantization.
//  - projection_weights_row_sums, compute_row_sums, context: Data for
//      optimized MatrixBatchVectorMultiplyAccumulate.
//  - scratch0: scratch area of size n_batch*n_cell
//  - scratch1: scratch area of size n_batch*n_cell
//  - scratch2: scratch area of size n_batch
//  - scratch3: scratch area of size n_batch
//  - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputHybrid(
    int n_batch, int n_cell, int n_output, const float* cell_state,
    const float* output_gate, TfLiteFusedActivation activation,
    const int8_t* projection_weights, const uint8_t* projection_weights_ledger,
    float projection_weights_scale, const float* projection_bias,
    const float proj_clip, float* output_state, bool asymmetric_quantize_inputs,
    int32_t* projection_weights_row_sums, bool* compute_row_sums,
    CpuBackendContext* context, float* scratch0, int8_t* scratch1,
    float* scratch2, int32_t* scratch3, int32_t* scratch4) {
  tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell,
                                        activation, scratch0);
  tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0,
                                         n_batch * n_cell, scratch0);

  const bool use_projection = (projection_weights != nullptr);
  const bool use_projection_bias = (projection_bias != nullptr);

  if (use_projection) {
    if (use_projection_bias) {
      tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch,
                                            output_state);
    } else {
      std::fill_n(output_state, n_batch * n_output, 0.0f);
    }
    if (!tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) {
      // Save quantization and matmul computation for all zero output.
      tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1,
                                        scratch2, scratch3,
                                        asymmetric_quantize_inputs);
      if (projection_weights_ledger != nullptr) {
        std::vector<float> scales(n_batch);
        for (int i = 0; i < n_batch; i++) {
          scales[i] = projection_weights_scale * scratch2[i];
        }
        tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
            projection_weights, projection_weights_ledger, n_output, n_cell,
            scratch1, scales.data(), n_batch, output_state);
      } else {
        tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            projection_weights, n_output, n_cell, scratch1,
            projection_weights_scale, scratch2, n_batch, output_state,
            /*per_channel_scale=*/nullptr, scratch3, scratch4,
            projection_weights_row_sums, compute_row_sums, scratch2, context);
      }
    }
    if (proj_clip > 0.0f) {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip);
    }
  } else {
    std::copy_n(scratch0, n_batch * n_output, output_state);
  }
}

// Calculates a single LSTM gate, int8x8_16 version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateInteger8x8_16(
    // Input and weights
    const int8_t* input, const int8_t* input_to_gate_weights,
    const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a,
    const int32_t input_to_gate_scale_b,
    // Output state and weights
    const int8_t* output_state, const int8_t* recurrent_to_gate_weights,
    const int32_t* recurrent_to_gate_bias,
    const int32_t recurrent_to_gate_scale_a,
    const int32_t recurrent_to_gate_scale_b,
    // Cell state and weights
    const int16_t* cell_state, const int16_t* cell_to_gate_weights,
    const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b,
    // Layer normalization parameters (layer norm LSTM)
    const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias,
    const int32_t layer_norm_input_scale_a,
    const int32_t layer_norm_input_scale_b,
    const int32_t layer_norm_variance_guard,
    // Array sizes
    const int n_batch, const int n_input, const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    int16_t* gate,
    // Parameters for performance optimizations
    CpuBackendContext* context,
    // Scratch arrays
    int32_t* scratch5) {
  const bool use_peephole = (cell_to_gate_weights != nullptr);
  const bool use_layer_norm = (layer_norm_coefficients != nullptr);
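
  // A note on conventions (not specific to this gate): each (*_scale_a,
  // *_scale_b) pair is assumed to be the usual TF Lite fixed-point encoding of
  // an effective scale, i.e. a 32-bit multiplier plus a shift, which the
  // integer kernels in tensor_utils apply after the raw int8 x int8 -> int32
  // accumulation.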

  // Initialize scratch buffers with zeros. Note that unlike float and hybrid
  // versions, bias is only used in layer normalization.
  std::fill_n(gate, n_batch * n_cell, 0);
  // For each batch and cell: compute input_weight * input.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a,
      input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate,
      context);
  // Note: no aux_input.

  // For each batch and cell: compute recurrent_weight * output_state.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      output_state, recurrent_to_gate_bias, recurrent_to_gate_weights,
      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
      n_cell, 0, scratch5, gate, context);
  // For each batch and cell: compute cell_weight * cell_state (peephole LSTM)
  if (use_peephole) {
    tensor_utils::VectorBatchVectorCwiseProductAccumulate(
        cell_to_gate_weights, n_output, cell_state, n_batch,
        cell_to_gate_scale_a, cell_to_gate_scale_b, gate);
  }
  // Do layer normalization (if layer norm LSTM)
  if (use_layer_norm) {
    tensor_utils::ApplyLayerNorm(
        gate, layer_norm_coefficients, layer_norm_bias,
        layer_norm_input_scale_a, layer_norm_input_scale_b,
        layer_norm_variance_guard, n_batch, n_cell, gate);
  }
  // Apply activation
  switch (activation) {
    case kTfLiteActSigmoid:
      tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate);
      break;
    case kTfLiteActTanh:
      tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate);
      break;
    default:
      // Only Sigmoid or Tanh is used.
      TFLITE_ASSERT_FALSE;
  }
}

// Updates the LSTM cell state, used by both integer LSTM versions.
// Also see UpdateLstmCellFloat.
//
// Parameters:
//  - n_batch, n_cell: sizes of vectors
//  - cell_state: input/output vector, size n_batch*n_cell
//  - cell_state_scale: scaling factor of cell state.
//  - input_gate: input vector, size n_batch*n_cell.
//  - forget_gate: input/scratch vector, size n_batch*n_cell, always modified.
//  - cell_gate: input vector, size n_batch*n_cell.
//  - use_cifg: use 1-forget_gate instead of input_gate.
//  - clip: if > 0, clip the resulting cell state to [-clip, +clip].
void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state,
                           int32_t cell_state_scale, const int16_t* input_gate,
                           int16_t* forget_gate, const int16_t* cell_gate,
                           bool use_cifg, int16_t clip) {
  // Use the forget_gate array as scratch, as input_gate array is not allocated
  // in CIFG case. (Be careful not to write to the scratch before reading the
  // forget gate data.)
  int16_t* scratch = forget_gate;

  tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15,
                         cell_state);
  if (use_cifg) {
    tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch);
    tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell,
                           30 + cell_state_scale, scratch);
  } else {
    tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell,
                           30 + cell_state_scale, scratch);
  }
  tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell, cell_state);

  if (clip > 0) {
    tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip);
  }
}
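
// A note on the shift amounts above, assuming the usual TF Lite convention
// that the sigmoid/tanh gate outputs are Q0.15 int16 values and that
// cell_state_scale is the power-of-two exponent of the int16 cell state:
// multiplying a Q0.15 gate with the cell state and shifting by 15 keeps the
// result at the cell state's scale, while multiplying two Q0.15 gates needs a
// shift of 30 + cell_state_scale to land at that same scale.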

// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
//  - n_batch: batches: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - cell_state_scale: scaling of cell_state.
//  - hidden_scale_[a|b]: effective scale of cell_state.*output_gate
//  - hidden_zp: zero_point for cell_state.*output_gate
//  - projection_weights, proj_scale_[a|b], projection_bias:
//      constant inputs, describing projection matrix and bias.
//  - output_state_zp: zero point of output_state. (Input, calibrated value.)
//  - quantized_proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - context: data for optimized MatrixBatchVectorMultiplyAccumulate.
//  - scratch0: scratch area of size n_batch*n_cell
//  - scratch1: scratch area of size n_batch*n_cell
//  - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate
void CalculateLstmOutputInteger8x8_16(
    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
    int32_t cell_state_scale, const int16_t* output_gate,
    int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp,
    const int8_t* projection_weights, int32_t proj_scale_a,
    int32_t proj_scale_b, const int32_t* projection_bias,
    int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state,
    CpuBackendContext* context, int16_t* scratch0, int8_t* scratch1,
    int32_t* scratch2) {
  // Note: unlike float/hybrid, the activation is always Tanh.
  tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch, n_cell,
                          scratch0);
  tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a, hidden_scale_b,
                         n_batch, n_cell, hidden_zp, scratch1);

  const bool use_projection = (projection_weights != nullptr);

  if (use_projection) {
    // Note: no bias like in float/hybrid
    std::fill_n(output_state, n_batch * n_output, 0);
    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
        scratch1, projection_bias, projection_weights, proj_scale_a,
        proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2,
        output_state, context);
    if (quantized_proj_clip > 0) {
      tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                  quantized_proj_clip);
    }
  } else {
    std::copy_n(scratch1, n_batch * n_output, output_state);
  }
}

// Calculates a single LSTM gate, int8x8_8 version.
// Implements the same functionality as CalculateLstmGateFloat.
void CalculateLstmGateInteger8x8_8(
    // Inputs and weights
    const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight,
    const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b,
    const int32_t input_times_weights_scale_a,
    const int32_t input_times_weights_scale_b,
    const int32_t input_times_weights_zp,
    // Output state and weights
    const int8_t* output_state, const int32_t output_state_zp,
    const int8_t* recurrent_to_gate_weight,
    const int32_t recurrent_to_gate_scale_a,
    const int32_t recurrent_to_gate_scale_b,
    const int32_t output_state_times_weights_scale_a,
    const int32_t output_state_times_weights_scale_b,
    const int32_t output_state_times_weights_zp,
    // Layer normalization parameters (layer norm LSTM)
    const int16_t* layer_norm_gate_weight,
    const int32_t layer_norm_gate_scale_a,
    const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias,
    // Array sizes
    const int n_batch, const int n_input, const int n_output, const int n_cell,
    const TfLiteFusedActivation activation,
    // Output
    int16_t* gate,
    // Scratch arrays, both sized n_batch*n_cell
    int8_t* scratch0, int8_t* scratch1) {
  // Multiply input * input_weights => scratch0
  tensor_utils::MatrixBatchVectorMultiply(
      input, input_zp, input_to_gate_weight, input_to_gate_scale_a,
      input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0,
      input_times_weights_zp);
  // Multiply output_state * recurrent_weights => scratch1
  tensor_utils::MatrixBatchVectorMultiply(
      output_state, output_state_zp, recurrent_to_gate_weight,
      recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output,
      n_cell, scratch1, output_state_times_weights_zp);
  // Add scratch0 + scratch1 => gate
  tensor_utils::TwoGateSaturatingAdd(
      scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp,
      input_times_weights_scale_a, input_times_weights_scale_b,
      output_state_times_weights_scale_a, output_state_times_weights_scale_b,
      n_batch, n_cell, gate);
  // Apply layer normalization.
  tensor_utils::ApplyLayerNormFloat(
      gate, layer_norm_gate_weight, layer_norm_gate_scale_a,
      layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate);
  // Apply activation.
  switch (activation) {
    case kTfLiteActSigmoid:
      tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate);
      break;
    case kTfLiteActTanh:
      tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate);
      break;
    default:
      // Only Sigmoid or Tanh is used.
      TFLITE_ASSERT_FALSE;
  }
}

// Calculates the output state tensor of an LSTM step. See Float and hybrid
// versions as well.
//
// Parameters:
//  - n_batch: batches: the number of distinct vectors in each array.
//  - n_cell, n_output: sizes of vectors.
//  - cell_state, output_gate: input vectors, size n_batch*n_cell.
//  - projection_weights, proj_scale_[a|b], projection_bias:
//      constant inputs, describing projection matrix and bias.
//  - output_state_zp: zero point of the output state.
//  - quantized_proj_clip: if > 0, clip the output of the projection.
//  - output_state: output vector, size n_batch*n_output. Must be contiguous.
//  - scratch: scratch area of size n_batch*n_cell
void CalculateLstmOutputInteger8x8_8(
    int n_batch, int n_cell, int n_output, const int16_t* cell_state,
    const int16_t* output_gate, const int8_t* projection_weights,
    int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias,
    int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state,
    int16_t* scratch) {
  // Note: unlike float/hybrid, the activation is always Tanh.
  tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch);
  tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell, 15 + 15 - 15,
                         scratch);
  // Note: no bias like in float/hybrid
  tensor_utils::MatrixBatchVectorMultiply(
      scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias,
      n_batch, n_cell, n_output, output_state_zp, output_state);
  if (quantized_proj_clip > 0) {
    tensor_utils::CwiseClipping(output_state, n_batch * n_output,
                                quantized_proj_clip);
  }
}

// Performs an LSTM batch inference step for input specified by input_ptr.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
// parameters:
//  - params: various LSTM params including activation, clipping, etc.,
//  - n_batch: size of batch,
//  - n_cell: number of cells (or units),
//  - n_input: the input size,
//  - n_aux_input: the auxiliary input size.
//  - n_output: the output size.
//  - output_batch_leading_dim: the leading dimension of the output buffer.
//  - context: the CpuBackendContext for use with matrix multiplications.
//
// Input of size 'n_batch * n_input':
//   input_ptr
// Input of size 'n_batch * n_aux_input':
//   aux_input_ptr                      - optional (can be nullptr)
//
// LSTM weights:
// Input weights of size 'n_cell * n_input':
//   input_to_input_weights             - optional
//   input_to_forget_weights
//   input_to_cell_weights
//   input_to_output_weights
// Auxiliary input weights of size 'n_cell * n_aux_input':
//   aux_input_to_input_weights         - optional
//   aux_input_to_forget_weights        - optional
//   aux_input_to_cell_weights          - optional
//   aux_input_to_output_weights        - optional
// Recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weights         - optional
//   recurrent_to_forget_weights
//   recurrent_to_cell_weights
//   recurrent_to_output_weights
// Peephole weights of size 'n_cell', representing diagonal matrices.
//   cell_to_input_weights              - optional
//   cell_to_forget_weights             - optional
//   cell_to_output_weights             - optional
// Projection weights of size 'n_output * n_cell'
//   projection_weights_ptr             - optional
// Gate biases of size 'n_cell':
//   input_gate_bias_ptr                - optional
//   forget_gate_bias_ptr
//   cell_gate_bias_ptr
//   output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
//   input_layer_norm_coefficients_ptr  - optional
//   forget_layer_norm_coefficients_ptr - optional
//   cell_layer_norm_coefficients_ptr   - optional
//   output_layer_norm_coefficients_ptr - optional
//
// The pointers to the cell and output state and the output are updated.
//
// The pointers input_ptr, aux_input_ptr, and output_ptr point to data laid out
// in batch-major order, and each step processes n_batch inputs from input_ptr
// and updates n_batch cell and output states.
//
// The output_batch_leading_dim is output.shape[-1], i.e. the innermost
// dimension of the output tensor, and in most cases is equal to n_output. It
// differs when the LSTM output is stored into a slice of the output tensor,
// e.g. for bidirectional LSTMs with merge_outputs. In that case the batched
// operations cannot be used, since they assume that the batched outputs are
// contiguous, and we manually loop over the batched outputs.
// LINT.IfChange
inline void LstmStepFloat(
    const float* input_ptr, const float* input_to_input_weights_ptr,
    const float* input_to_forget_weights_ptr,
    const float* input_to_cell_weights_ptr,
    const float* input_to_output_weights_ptr, const float* aux_input_ptr,
    const float* aux_input_to_input_weights_ptr,
    const float* aux_input_to_forget_weights_ptr,
    const float* aux_input_to_cell_weights_ptr,
    const float* aux_input_to_output_weights_ptr,
    const float* recurrent_to_input_weights_ptr,
    const float* recurrent_to_forget_weights_ptr,
    const float* recurrent_to_cell_weights_ptr,
    const float* recurrent_to_output_weights_ptr,
    const float* cell_to_input_weights_ptr,
    const float* cell_to_forget_weights_ptr,
    const float* cell_to_output_weights_ptr,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
    const float* projection_weights_ptr, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* output_state_ptr, float* cell_state_ptr, float* scratch0,
    float* scratch1, float* scratch2, float* scratch3, float* scratch4,
    float* output_ptr, CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("LstmStepFloat");
  // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);

  // Make named scratch buffers.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;
  float* accumulation_scratch_buffer = scratch4;

  // Check if inputs are all zeros so we can skip some computations.
  const bool is_input_all_zeros =
      tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
  const bool is_aux_input_all_zeros =
      (aux_input_ptr == nullptr ||
       tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));

  if (!use_cifg) {
    // Calculate the input gate. (If not CIFG.)
    CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr,
                           aux_input_to_input_weights_ptr, output_state_ptr,
                           recurrent_to_input_weights_ptr, cell_state_ptr,
                           cell_to_input_weights_ptr,
                           input_layer_norm_coefficients_ptr,
                           input_gate_bias_ptr, n_batch, n_input, n_aux_input,
                           n_output, n_cell,
                           /*activation=*/kTfLiteActSigmoid, input_gate_scratch,
                           is_input_all_zeros, is_aux_input_all_zeros,
                           accumulation_scratch_buffer, context);
  }
  // Calculate the forget gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_forget_weights_ptr, aux_input_ptr,
      aux_input_to_forget_weights_ptr, output_state_ptr,
      recurrent_to_forget_weights_ptr, cell_state_ptr,
      cell_to_forget_weights_ptr, forget_layer_norm_coefficients_ptr,
      forget_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, accumulation_scratch_buffer, context);
  // Calculate the cell update gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_cell_weights_ptr, aux_input_ptr,
      aux_input_to_cell_weights_ptr, output_state_ptr,
      recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
      /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr,
      cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      params->activation, cell_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, accumulation_scratch_buffer, context);
  // Update the cell state.
  UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
                      forget_gate_scratch, cell_gate_scratch, use_cifg,
                      params->cell_clip);
  // Calculate output gate.
  CalculateLstmGateFloat(
      input_ptr, input_to_output_weights_ptr, aux_input_ptr,
      aux_input_to_output_weights_ptr, output_state_ptr,
      recurrent_to_output_weights_ptr, cell_state_ptr,
      cell_to_output_weights_ptr, output_layer_norm_coefficients_ptr,
      output_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
      /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros,
      is_aux_input_all_zeros, accumulation_scratch_buffer, context);
  // Update the output state.
  CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr,
                           output_gate_scratch, params->activation,
                           projection_weights_ptr, projection_bias_ptr,
                           params->proj_clip, output_state_ptr, scratch2,
                           accumulation_scratch_buffer, context);
  // Copy output state to the output. Note that the output's rows may not be
  // contiguous (output_batch_leading_dim != n_output).
  for (int b = 0; b < n_batch; b++) {
    std::copy_n(output_state_ptr + b * n_output, n_output,
                output_ptr + b * output_batch_leading_dim);
  }
}
// LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\
// ../experimental/kernels/fp16/lstm_eval.cc)
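
// A minimal sketch (for orientation only; the names are illustrative and the
// real sequence loops live in this file's Eval* functions) of how a caller
// drives LstmStepFloat over a time-major input of shape
// [max_time, n_batch, n_input]:
//
//   for (int t = 0; t < max_time; ++t) {
//     const float* input_t = input + t * n_batch * n_input;
//     float* output_t = output + t * n_batch * output_batch_leading_dim;
//     LstmStepFloat(input_t, /*...weights, biases, params...*/,
//                   output_state, cell_state, scratch0, scratch1, scratch2,
//                   scratch3, scratch4, output_t, context);
//   }
//
// output_state and cell_state persist across steps, which is what carries the
// recurrence from one time step to the next.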

// Same as above but with quantized weight matrices. In detail:
// Input of size 'n_batch * n_input':
//   input_ptr
// Input of size 'n_batch * n_aux_input':
//   aux_input_ptr                      - optional (can be nullptr)
//
// LSTM weights:
// Quantized input weights of size 'n_cell * n_input':
//   input_to_input_weights             - optional
//   input_to_forget_weights
//   input_to_cell_weights
//   input_to_output_weights
// Quantized auxiliary input weights of size 'n_cell * n_aux_input':
//   aux_input_to_input_weights         - optional
//   aux_input_to_forget_weights        - optional
//   aux_input_to_cell_weights          - optional
//   aux_input_to_output_weights        - optional
// Quantized recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weights         - optional
//   recurrent_to_forget_weights
//   recurrent_to_cell_weights
//   recurrent_to_output_weights
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
//   cell_to_input_weights              - optional
//   cell_to_forget_weights             - optional
//   cell_to_output_weights             - optional
// Quantized projection weights of size 'n_output * n_cell'
//   projection_weights_ptr             - optional
// Weight scales (scalars) for each of the weights above.
//   input_to_input_weights_scale       - optional
//   input_to_forget_weights_scale
//   input_to_cell_weights_scale
//   input_to_output_weights_scale
//   aux_input_to_input_weights_scale   - optional
//   aux_input_to_forget_weights_scale  - optional
//   aux_input_to_cell_weights_scale    - optional
//   aux_input_to_output_weights_scale  - optional
//   recurrent_to_input_weights_scale   - optional
//   recurrent_to_forget_weights_scale
//   recurrent_to_cell_weights_scale
//   recurrent_to_output_weights_scale
//   cell_to_input_weights_scale
//   cell_to_forget_weights_scale
//   cell_to_output_weights_scale
//   projection_weights_scale           - optional
// Gate biases of size 'n_cell':
//   input_gate_bias_ptr                - optional
//   forget_gate_bias_ptr
//   cell_gate_bias_ptr
//   output_gate_bias_ptr
//
// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
//   input_layer_norm_coefficients_ptr  - optional
//   forget_layer_norm_coefficients_ptr - optional
//   cell_layer_norm_coefficients_ptr   - optional
//   output_layer_norm_coefficients_ptr - optional
//
// Temporary pre-allocated storage for quantized values:
//   quantized_input_ptr (same size as input_ptr)
//   quantized_output_state_ptr (same size as output_state_ptr)
//   quantized_output_scratch (same size as cell_state_ptr)
// Temporary pre-allocated storage for recovered values:
//   recovered_cell_weights (same size as cell_to_*_weights)
//
// Outputs:
//   output_state_ptr - size 'n_batch * n_output'
//   cell_state_ptr   - size 'n_batch * n_cell'
//   output_ptr       - size 'n_batch * output_batch_leading_dim'
inline void LstmStepHybrid(
    const float* input_ptr, const int8_t* input_to_input_weights_ptr,
    const uint8_t* input_to_input_weights_ledger_ptr,
    float input_to_input_weights_scale,
    const int8_t* input_to_forget_weights_ptr,
    const uint8_t* input_to_forget_weights_ledger_ptr,
    float input_to_forget_weights_scale,
    const int8_t* input_to_cell_weights_ptr,
    const uint8_t* input_to_cell_weights_ledger_ptr,
    float input_to_cell_weights_scale,
    const int8_t* input_to_output_weights_ptr,
    const uint8_t* input_to_output_weights_ledger_ptr,
    float input_to_output_weights_scale, const float* aux_input_ptr,
    const int8_t* aux_input_to_input_weights_ptr,
    float aux_input_to_input_weights_scale,
    const int8_t* aux_input_to_forget_weights_ptr,
    float aux_input_to_forget_weights_scale,
    const int8_t* aux_input_to_cell_weights_ptr,
    float aux_input_to_cell_weights_scale,
    const int8_t* aux_input_to_output_weights_ptr,
    float aux_input_to_output_weights_scale,
    const int8_t* recurrent_to_input_weights_ptr,
    const uint8_t* recurrent_to_input_weights_ledger_ptr,
    float recurrent_to_input_weights_scale,
    const int8_t* recurrent_to_forget_weights_ptr,
    const uint8_t* recurrent_to_forget_weights_ledger_ptr,
    float recurrent_to_forget_weights_scale,
    const int8_t* recurrent_to_cell_weights_ptr,
    const uint8_t* recurrent_to_cell_weights_ledger_ptr,
    float recurrent_to_cell_weights_scale,
    const int8_t* recurrent_to_output_weights_ptr,
    const uint8_t* recurrent_to_output_weights_ledger_ptr,
    float recurrent_to_output_weights_scale,
    const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
    const int8_t* cell_to_forget_weights_ptr,
    float cell_to_forget_weights_scale,
    const int8_t* cell_to_output_weights_ptr,
    float cell_to_output_weights_scale,
    const float* input_layer_norm_coefficients_ptr,
    const float* forget_layer_norm_coefficients_ptr,
    const float* cell_layer_norm_coefficients_ptr,
    const float* output_layer_norm_coefficients_ptr,
    const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr,
    const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr,
    const int8_t* projection_weights_ptr,
    const uint8_t* projection_weights_ledger_ptr,
    float projection_weights_scale, const float* projection_bias_ptr,
    const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
    int n_aux_input, int n_output, int output_batch_leading_dim,
    float* scratch0, float* scratch1, float* scratch2, float* scratch3,
    float* input_sf, float* aux_input_sf, float* output_state_sf,
    float* scaling_factors_scratch, float* recovered_cell_weights,
    int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr,
    int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch,
    float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr,
    float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp,
    int32_t* output_state_zp, int32_t* row_sums, int row_sums_size,
    bool* compute_row_sums, bool asymmetric_quantize_inputs,
    CpuBackendContext* context) {
  ruy::profiler::ScopeLabel label("LstmStepHybrid");
  // Since we have already checked that weights are all there or none, we
  // can check the existence of only one to get the condition.
  const bool use_cifg = (input_to_input_weights_ptr == nullptr);
  // Make named scratch buffers for the different gates.
  float* input_gate_scratch = scratch0;
  float* forget_gate_scratch = scratch1;
  float* cell_gate_scratch = scratch2;
  float* output_gate_scratch = scratch3;

  int32_t* input_to_input_row_sums = nullptr;
  int32_t* input_to_forget_row_sums = nullptr;
  int32_t* input_to_cell_row_sums = nullptr;
  int32_t* input_to_output_row_sums = nullptr;
  int32_t* aux_input_to_input_row_sums = nullptr;
  int32_t* aux_input_to_forget_row_sums = nullptr;
  int32_t* aux_input_to_cell_row_sums = nullptr;
  int32_t* aux_input_to_output_row_sums = nullptr;
  int32_t* recurrent_to_input_row_sums = nullptr;
  int32_t* recurrent_to_forget_row_sums = nullptr;
  int32_t* recurrent_to_cell_row_sums = nullptr;
  int32_t* recurrent_to_output_row_sums = nullptr;
  int32_t* projection_weights_row_sums = nullptr;

  if (asymmetric_quantize_inputs) {
    int num_row_sums = use_cifg ? 6 : 8;
    if (aux_input_ptr != nullptr) {
      num_row_sums += use_cifg ? 3 : 4;
    }
    if (projection_weights_ptr != nullptr) {
      num_row_sums += ceil(static_cast<float>(n_output) / n_cell);
    }
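    // Each count above is in blocks of n_cell row sums: 8 input/recurrent
    // matrices (6 with CIFG), plus 4 (or 3 with CIFG) aux-input matrices when
    // an aux input is present. The projection matrix has n_output rows, so it
    // takes ceil(n_output / n_cell) blocks of the shared row_sums buffer.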
1124 TF_LITE_ASSERT(row_sums_size == num_row_sums);
1125 input_to_input_row_sums = row_sums;
1126 input_to_forget_row_sums =
1127 use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell;
1128 input_to_cell_row_sums = input_to_forget_row_sums + n_cell;
1129 input_to_output_row_sums = input_to_cell_row_sums + n_cell;
1130 if (aux_input_ptr != nullptr) {
1131 aux_input_to_input_row_sums = input_to_output_row_sums + n_cell;
1132 aux_input_to_forget_row_sums = use_cifg
1133 ? aux_input_to_input_row_sums
1134 : aux_input_to_input_row_sums + n_cell;
1135 aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell;
1136 aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell;
1137 }
1138 recurrent_to_input_row_sums = aux_input_ptr
1139 ? aux_input_to_output_row_sums + n_cell
1140 : input_to_output_row_sums + n_cell;
1141 recurrent_to_forget_row_sums = use_cifg
1142 ? recurrent_to_input_row_sums
1143 : recurrent_to_input_row_sums + n_cell;
1144 recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell;
1145 recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell;
1146 if (projection_weights_ptr != nullptr) {
1147 projection_weights_row_sums = recurrent_to_output_row_sums + n_cell;
1148 }
1149 if (*compute_row_sums) {
1150 ComputeRowSums(
1151 input_to_input_row_sums, input_to_forget_row_sums,
1152 input_to_cell_row_sums, input_to_output_row_sums,
1153 aux_input_to_input_row_sums, aux_input_to_forget_row_sums,
1154 aux_input_to_cell_row_sums, aux_input_to_output_row_sums,
1155 recurrent_to_input_row_sums, recurrent_to_forget_row_sums,
1156 recurrent_to_cell_row_sums, recurrent_to_output_row_sums,
1157 projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input,
1158 n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr,
1159 input_to_cell_weights_ptr, input_to_output_weights_ptr,
1160 aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
1161 aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
1162 recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
1163 recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
1164 projection_weights_ptr, use_cifg, aux_input_ptr);
1165 *compute_row_sums = false;
1166 }
1167 }
1168
1169 // Check if inputs are all zeros so we can skip some computations.
1170 const bool is_input_all_zeros =
1171 tensor_utils::IsZeroVector(input_ptr, n_batch * n_input);
1172 const bool is_aux_input_all_zeros =
1173 (aux_input_ptr == nullptr ||
1174 tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input));
1175 const bool is_output_state_all_zeros =
1176 tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output);
1177 // Quantize inputs.
1178 if (!is_input_all_zeros) {
1179 tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input,
1180 quantized_input_ptr, input_sf, input_zp,
1181 asymmetric_quantize_inputs);
1182 }
1183 if (!is_aux_input_all_zeros) {
1184 tensor_utils::BatchQuantizeFloats(aux_input_ptr, n_batch, n_aux_input,
1185 quantized_aux_input_ptr, aux_input_sf,
1186 aux_input_zp, asymmetric_quantize_inputs);
1187 }
1188 if (!is_output_state_all_zeros) {
1189 tensor_utils::BatchQuantizeFloats(
1190 output_state_ptr, n_batch, n_output, quantized_output_state_ptr,
1191 output_state_sf, output_state_zp, asymmetric_quantize_inputs);
1192 }
1193 if (!use_cifg) {
1194 // Calculate the input gate. (If not CIFG.)
1195 CalculateLstmGateHybrid(
1196 quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr,
1197 input_to_input_weights_ledger_ptr, input_to_input_weights_scale,
1198 input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf,
1199 aux_input_zp, aux_input_to_input_weights_ptr,
1200 aux_input_to_input_weights_scale, aux_input_to_input_row_sums,
1201 quantized_output_state_ptr, output_state_sf, output_state_zp,
1202 recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr,
1203 recurrent_to_input_weights_scale, recurrent_to_input_row_sums,
1204 cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale,
1205 input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch,
1206 n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1207 input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1208 is_output_state_all_zeros, compute_row_sums, context,
1209 scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1210 }
1211 // Calculate the forget gate.
1212 CalculateLstmGateHybrid(
1213 quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr,
1214 input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale,
1215 input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf,
1216 aux_input_zp, aux_input_to_forget_weights_ptr,
1217 aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums,
1218 quantized_output_state_ptr, output_state_sf, output_state_zp,
1219 recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr,
1220 recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums,
1221 cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
1222 forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch,
1223 n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1224 forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1225 is_output_state_all_zeros, compute_row_sums, context,
1226 scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1227 // Calculate the cell update gate.
1228 CalculateLstmGateHybrid(
1229 quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr,
1230 input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale,
1231 input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf,
1232 aux_input_zp, aux_input_to_cell_weights_ptr,
1233 aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums,
1234 quantized_output_state_ptr, output_state_sf, output_state_zp,
1235 recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr,
1236 recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums,
1237 /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr,
1238 /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr,
1239 cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell,
1240 params->activation, cell_gate_scratch, is_input_all_zeros,
1241 is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums,
1242 context, scaling_factors_scratch, recovered_cell_weights,
1243 accum_scratch_ptr);
1244 // Update the cell state.
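// Conceptually: cell_state = forget_gate .* cell_state + input_gate .* cell_gate
// (element-wise), where the input gate is implicitly 1 - forget_gate under CIFG;
// the result is clipped to [-cell_clip, cell_clip] when cell_clip > 0.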
1245 UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch,
1246 forget_gate_scratch, cell_gate_scratch, use_cifg,
1247 params->cell_clip);
1248 // Calculate the output gate.
1249 CalculateLstmGateHybrid(
1250 quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr,
1251 input_to_output_weights_ledger_ptr, input_to_output_weights_scale,
1252 input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf,
1253 aux_input_zp, aux_input_to_output_weights_ptr,
1254 aux_input_to_output_weights_scale, aux_input_to_output_row_sums,
1255 quantized_output_state_ptr, output_state_sf, output_state_zp,
1256 recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr,
1257 recurrent_to_output_weights_scale, recurrent_to_output_row_sums,
1258 cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale,
1259 output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch,
1260 n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid,
1261 output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros,
1262 is_output_state_all_zeros, compute_row_sums, context,
1263 scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr);
1264 // Update the output state.
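// Conceptually: hidden = output_gate .* activation(cell_state); the output state
// is hidden itself, or clip(projection_weights * hidden + projection_bias,
// proj_clip) when projection weights are present.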
1265 CalculateLstmOutputHybrid(
1266 n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
1267 params->activation, projection_weights_ptr, projection_weights_ledger_ptr,
1268 projection_weights_scale, projection_bias_ptr, params->proj_clip,
1269 output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums,
1270 compute_row_sums, context, scratch2, quantized_output_scratch, input_sf,
1271 input_zp, accum_scratch_ptr);
1272 // Copy output state to the output. Note that the output's rows may not be
1273 // contiguous (output_batch_leading_dim != n_output).
1274 for (int b = 0; b < n_batch; b++) {
1275 std::copy_n(output_state_ptr + b * n_output, n_output,
1276 output_ptr + b * output_batch_leading_dim);
1277 }
1278}
1279
1280// Fully quantized LSTM kernel for 16-bit gate matmul output.
1281//
1282// Input tensor of size n_batch * n_input:
1283// input_ptr
1284//
1285// LSTM weights:
1286// Quantized input weights of size 'n_cell * n_input':
1287// input_to_input_weight_ptr - optional
1288// input_to_forget_weight_ptr
1289// input_to_cell_weight_ptr
1290// input_to_output_weight_ptr
1291//
1292// Quantized recurrent weights of size 'n_cell * n_output':
1293// recurrent_to_input_weight_ptr - optional
1294// recurrent_to_forget_weights_ptr
1295// recurrent_to_cell_weights_ptr
1296// recurrent_to_output_weights_ptr
1297//
1298// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
1299// cell_to_input_weights - optional
1300// cell_to_forget_weights - optional
1301// cell_to_output_weights - optional
1302//
1303// Quantized projection weights of size 'n_output * n_cell'
1304// projection_weight_ptr - optional
1305//
1306// Weight scales (scalars) for each of the weights above.
1307// effective_input_to_input_scale_a - optional
1308// effective_input_to_input_scale_b - optional
1309// effective_input_to_forget_scale_a
1310// effective_input_to_forget_scale_b
1311// effective_input_to_cell_scale_a
1312// effective_input_to_cell_scale_b
1313// effective_input_to_output_scale_a
1314// effective_input_to_output_scale_b
1315// effective_recurrent_to_input_scale_a - optional
1316// effective_recurrent_to_input_scale_b - optional
1317// effective_recurrent_to_forget_scale_a
1318// effective_recurrent_to_forget_scale_b
1319// effective_recurrent_to_cell_scale_a
1320// effective_recurrent_to_cell_scale_b
1321// effective_recurrent_to_output_scale_a
1322// effective_recurrent_to_output_scale_b
1323// effective_proj_scale_a - optional
1324// effective_proj_scale_b - optional
1325//
1326// Gate biases of size 'n_cell':
1327// input_gate_bias_ptr - optional
1328// forget_gate_bias_ptr
1329// cell_gate_bias_ptr
1330// output_gate_bias_ptr
1331//
1332// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
1333// layer_norm_input_weight_ptr - optional
1334// layer_norm_forget_weight_ptr - optional
1335// layer_norm_cell_weight_ptr - optional
1336// layer_norm_output_weight_ptr - optional
1337//
1338// Layer norm scales of size 'n_cell'.
1339// layer_norm_input_scale_a - optional
1340// layer_norm_input_scale_b - optional
1341// layer_norm_forget_scale_a - optional
1342// layer_norm_forget_scale_b - optional
1343// layer_norm_cell_scale_a - optional
1344// layer_norm_cell_scale_b - optional
1345// layer_norm_output_scale_a - optional
1346// layer_norm_output_scale_b - optional
1347//
1348// Scalar values:
1349// quantized_cell_clip: quantized clip value for cell.
1350// quantized_proj_clip: quantized clip value for projection.
1351// cell_state_scale: the power of two scale for cell state.
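// (For example, a cell_state_scale of -11 means a stored int16 cell value q
// represents the real value q * 2^-11.)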
1352//
1353// Zero points:
1354// output_state_zp: zero point of output state
1355// hidden_zp: zero point for hidden state.
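// ('hidden' is the intermediate output_gate .* tanh(cell_state) result that
// is requantized to int8 before the projection.)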
1356//
1357// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
1358// n_batch.
1359// scratch0
1360// scratch1
1361// scratch2
1362// scratch3
1363// scratch4
1364// scratch5: this scratch buffer is created purely for optimizing the
1365// MatrixBatchVectorMultiplyAccumulate.
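// (scratch5 holds the int32 matmul accumulators and is reused by every gate.)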
1366//
1367// Outputs:
1368// output_state_ptr - size 'n_batch * n_output'
1369// cell_state_ptr - size 'n_batch * n_cell'
1370// output_ptr - size 'n_batch * n_output'
1371// TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then.
1372inline void LstmStepInteger8x8_16(
1373 const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr,
1374 int32_t effective_input_to_input_scale_a,
1375 int32_t effective_input_to_input_scale_b,
1376 const int8_t* input_to_forget_weight_ptr,
1377 int32_t effective_input_to_forget_scale_a,
1378 int32_t effective_input_to_forget_scale_b,
1379 const int8_t* input_to_cell_weight_ptr,
1380 int32_t effective_input_to_cell_scale_a,
1381 int32_t effective_input_to_cell_scale_b,
1382 const int8_t* input_to_output_weight_ptr,
1383 int32_t effective_input_to_output_scale_a,
1384 int32_t effective_input_to_output_scale_b,
1385 const int8_t* recurrent_to_input_weight_ptr,
1386 int32_t effective_recurrent_to_input_scale_a,
1387 int32_t effective_recurrent_to_input_scale_b,
1388 const int8_t* recurrent_to_forget_weight_ptr,
1389 int32_t effective_recurrent_to_forget_scale_a,
1390 int32_t effective_recurrent_to_forget_scale_b,
1391 const int8_t* recurrent_to_cell_weight_ptr,
1392 int32_t effective_recurrent_to_cell_scale_a,
1393 int32_t effective_recurrent_to_cell_scale_b,
1394 const int8_t* recurrent_to_output_weight_ptr,
1395 int32_t effective_recurrent_to_output_scale_a,
1396 int32_t effective_recurrent_to_output_scale_b,
1397 const int16_t* cell_to_input_weight_ptr,
1398 int32_t effective_cell_to_input_scale_a,
1399 int32_t effective_cell_to_input_scale_b,
1400 const int16_t* cell_to_forget_weight_ptr,
1401 int32_t effective_cell_to_forget_scale_a,
1402 int32_t effective_cell_to_forget_scale_b,
1403 const int16_t* cell_to_output_weight_ptr,
1404 int32_t effective_cell_to_output_scale_a,
1405 int32_t effective_cell_to_output_scale_b,
1406 const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
1407 int32_t effective_proj_scale_b, int32_t hidden_zp,
1408 int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b,
1409 const int16_t* layer_norm_input_weight_ptr,
1410 int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
1411 const int16_t* layer_norm_forget_weight_ptr,
1412 int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
1413 const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
1414 int32_t layer_norm_cell_scale_b,
1415 const int16_t* layer_norm_output_weight_ptr,
1416 int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
1417 const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
1418 const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
1419 int16_t quantized_cell_clip, int8_t quantized_proj_clip,
1420 int32_t cell_state_scale, int32_t input_variance_guard,
1421 int32_t forget_variance_guard, int32_t cell_variance_guard,
1422 int32_t output_variance_guard,
1423 const int32_t* input_to_forget_effective_bias,
1424 const int32_t* recurrent_to_forget_effective_bias,
1425 const int32_t* input_to_cell_effective_bias,
1426 const int32_t* recurrent_to_cell_effective_bias,
1427 const int32_t* input_to_output_effective_bias,
1428 const int32_t* recurrent_to_output_effective_bias,
1429 const int32_t* input_to_input_effective_bias,
1430 const int32_t* recurrent_to_input_effective_bias,
1431 const int32_t* projection_effective_bias, int n_batch, int n_cell,
1432 int n_input, int n_output, int8_t* output_state_ptr,
1433 int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
1434 int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3,
1435 int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) {
1436 ruy::profiler::ScopeLabel label("LstmStepInteger8x8_16");
1437 // Make named scratch buffers for the different gates.
1438 int16_t* input_gate_scratch = scratch0;
1439 int16_t* forget_gate_scratch = scratch1;
1440 int16_t* cell_gate_scratch = scratch2;
1441 int16_t* output_gate_scratch = scratch3;
1442
1443 // Since we have already checked that weights are all there or none, we
1444 // can check the existence of only one to get the condition.
1445 const bool use_cifg = (input_to_input_weight_ptr == nullptr);
1446
1447 // Check for nullptrs.
1448 TFLITE_DCHECK(input_to_forget_effective_bias);
1449 TFLITE_DCHECK(recurrent_to_forget_effective_bias);
1450 TFLITE_DCHECK(input_to_cell_effective_bias);
1451 TFLITE_DCHECK(recurrent_to_cell_effective_bias);
1452 TFLITE_DCHECK(input_to_output_effective_bias);
1453 TFLITE_DCHECK(recurrent_to_output_effective_bias);
1454 if (!use_cifg) {
1455 TFLITE_DCHECK(input_to_input_effective_bias);
1456 TFLITE_DCHECK(recurrent_to_input_effective_bias);
1457 }
1458 const bool use_projection = (projection_weight_ptr != nullptr);
1459 if (use_projection) {
1460 TFLITE_DCHECK(projection_effective_bias);
1461 }
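// Each gate below is computed with int8 x int8 -> int32 matmuls whose results
// are rescaled by the corresponding effective_*_scale_{a,b} multiplier/shift
// pairs, followed by optional layer normalization and the gate activation,
// yielding int16 gate values.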
1462 if (!use_cifg) {
1463 // Calculate the input gate. (If not CIFG.)
1464 CalculateLstmGateInteger8x8_16(
1465 input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias,
1466 effective_input_to_input_scale_a, effective_input_to_input_scale_b,
1467 output_state_ptr, recurrent_to_input_weight_ptr,
1468 recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a,
1469 effective_recurrent_to_input_scale_b, cell_state_ptr,
1470 cell_to_input_weight_ptr, effective_cell_to_input_scale_a,
1471 effective_cell_to_input_scale_b, layer_norm_input_weight_ptr,
1472 input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b,
1473 input_variance_guard, n_batch, n_input, n_output, n_cell,
1474 kTfLiteActSigmoid, input_gate_scratch, context, scratch5);
1475 }
1476 // Calculate the forget gate.
1477 CalculateLstmGateInteger8x8_16(
1478 input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias,
1479 effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
1480 output_state_ptr, recurrent_to_forget_weight_ptr,
1481 recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a,
1482 effective_recurrent_to_forget_scale_b, cell_state_ptr,
1483 cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a,
1484 effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr,
1485 forget_gate_bias_ptr, layer_norm_forget_scale_a,
1486 layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input,
1487 n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, context,
1488 scratch5);
1489 // Calculate the cell update gate.
1490 CalculateLstmGateInteger8x8_16(
1491 input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias,
1492 effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
1493 output_state_ptr, recurrent_to_cell_weight_ptr,
1494 recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a,
1495 effective_recurrent_to_cell_scale_b, cell_state_ptr,
1496 /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0,
1497 /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr,
1498 cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b,
1499 cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh,
1500 cell_gate_scratch, context, scratch5);
1501 // Update the cell state.
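// Same recurrence as the float kernel (cell = forget .* cell + input .* cell_gate,
// with input = 1 - forget under CIFG), here in fixed point at the power-of-two
// cell_state_scale and clipped by quantized_cell_clip.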
1502 UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale,
1503 input_gate_scratch, forget_gate_scratch,
1504 cell_gate_scratch, use_cifg, quantized_cell_clip);
1505 // Calculate the output gate.
1506 CalculateLstmGateInteger8x8_16(
1507 input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias,
1508 effective_input_to_output_scale_a, effective_input_to_output_scale_b,
1509 output_state_ptr, recurrent_to_output_weight_ptr,
1510 recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a,
1511 effective_recurrent_to_output_scale_b, cell_state_ptr,
1512 cell_to_output_weight_ptr, effective_cell_to_output_scale_a,
1513 effective_cell_to_output_scale_b, layer_norm_output_weight_ptr,
1514 output_gate_bias_ptr, layer_norm_output_scale_a,
1515 layer_norm_output_scale_b, output_variance_guard, n_batch, n_input,
1516 n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, context,
1517 scratch5);
1518 // Update the output state.
1519 CalculateLstmOutputInteger8x8_16(
1520 n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale,
1521 output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b,
1522 hidden_zp, projection_weight_ptr, effective_proj_scale_a,
1523 effective_proj_scale_b, projection_effective_bias, output_state_zp,
1524 quantized_proj_clip, output_state_ptr, context, scratch0, scratch4,
1525 scratch5);
1526 // Copy output state to the output. Note that unlike float or hybrid, output
1527 // is always contiguous.
1528 std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
1529}
1530
1531// Fully quantized LSTM kernel for 8-bit gate matmul output.
1532//
1533// Input tensor of size n_batch * n_input:
1534// input_ptr
1535//
1536// LSTM weights:
1537// Quantized input weights of size 'n_cell * n_input':
1538// input_to_input_weight_ptr - optional
1539// input_to_forget_weight_ptr
1540// input_to_cell_weight_ptr
1541// input_to_output_weight_ptr
1542//
1543// Quantized recurrent weights of size 'n_cell * n_output':
1544// recurrent_to_input_weight_ptr - optional
1545// recurrent_to_forget_weights_ptr
1546// recurrent_to_cell_weights_ptr
1547// recurrent_to_output_weights_ptr
1548//
1549// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
1550// cell_to_input_weights - optional
1551// cell_to_forget_weights - optional
1552// cell_to_output_weights - optional
1553//
1554// Quantized projection weights of size 'n_output * n_cell'
1555// projection_weight_ptr - optional
1556//
1557// Weight scales (scalars) for each of the weights above.
1558// effective_input_to_input_scale_a - optional
1559// effective_input_to_input_scale_b - optional
1560// effective_input_to_forget_scale_a
1561// effective_input_to_forget_scale_b
1562// effective_input_to_cell_scale_a
1563// effective_input_to_cell_scale_b
1564// effective_input_to_output_scale_a
1565// effective_input_to_output_scale_b
1566// effective_recurrent_to_input_scale_a - optional
1567// effective_recurrent_to_input_scale_b - optional
1568// effective_recurrent_to_forget_scale_a
1569// effective_recurrent_to_forget_scale_b
1570// effective_recurrent_to_cell_scale_a
1571// effective_recurrent_to_cell_scale_b
1572// effective_recurrent_to_output_scale_a
1573// effective_recurrent_to_output_scale_b
1574// effective_proj_scale_a - optional
1575// effective_proj_scale_b - optional
1576//
1577// Gate biases of size 'n_cell':
1578// input_gate_bias_ptr - optional
1579// forget_gate_bias_ptr
1580// cell_gate_bias_ptr
1581// output_gate_bias_ptr
1582//
1583// Layer norm coefficients of size 'n_cell', representing diagonal matrices.
1584// layer_norm_input_weight_ptr - optional
1585// layer_norm_forget_weight_ptr - optional
1586// layer_norm_cell_weight_ptr - optional
1587// layer_norm_output_weight_ptr - optional
1588//
1589// Layer norm scales of size 'n_cell'.
1590// layer_norm_input_scale_a - optional
1591// layer_norm_input_scale_b - optional
1592// layer_norm_forget_scale_a - optional
1593// layer_norm_forget_scale_b - optional
1594// layer_norm_cell_scale_a - optional
1595// layer_norm_cell_scale_b - optional
1596// layer_norm_output_scale_a - optional
1597// layer_norm_output_scale_b - optional
1598//
1599// Scalar values:
1600// quantized_cell_clip: quantized clip value for cell.
1601// quantized_proj_clip: quantized clip value for projection.
1602// cell_state_scale: the power of two scale for cell state.
1603//
1604// Zero points:
1605// input_zp: zero point for input tensor.
1606// output_state_zp: zero point of output state.
1607// hidden_zp: zero point for hidden state.
1608//
1609// Temporary pre-allocated storage for the calculation. Each is of size n_cell *
1610// n_batch.
1611// scratch0
1612// scratch1
1613// scratch2
1614// scratch3
1615// scratch4
1616// scratch5
1617// scratch6
1618// scratch7
1619//
1620// Outputs:
1621// output_state_ptr - size 'n_batch * n_output'
1622// cell_state_ptr - size 'n_batch * n_cell'
1623// output_ptr - size 'n_batch * n_output'
1624//
1625// The zero point calculation could be moved into Prepare() for better performance.
1626// TODO(b/159947023): scratch5 is unused, remove.
1627inline void LstmStepInteger8x8_8(
1628 const int8_t* input_ptr, int32_t input_zp,
1629 const int8_t* input_to_input_weight_ptr,
1630 int32_t effective_input_to_input_scale_a,
1631 int32_t effective_input_to_input_scale_b,
1632 const int8_t* input_to_forget_weight_ptr,
1633 int32_t effective_input_to_forget_scale_a,
1634 int32_t effective_input_to_forget_scale_b,
1635 const int8_t* input_to_cell_weight_ptr,
1636 int32_t effective_input_to_cell_scale_a,
1637 int32_t effective_input_to_cell_scale_b,
1638 const int8_t* input_to_output_weight_ptr,
1639 int32_t effective_input_to_output_scale_a,
1640 int32_t effective_input_to_output_scale_b,
1641 const int8_t* recurrent_to_input_weight_ptr,
1642 int32_t effective_recurrent_to_input_scale_a,
1643 int32_t effective_recurrent_to_input_scale_b,
1644 const int8_t* recurrent_to_forget_weight_ptr,
1645 int32_t effective_recurrent_to_forget_scale_a,
1646 int32_t effective_recurrent_to_forget_scale_b,
1647 const int8_t* recurrent_to_cell_weight_ptr,
1648 int32_t effective_recurrent_to_cell_scale_a,
1649 int32_t effective_recurrent_to_cell_scale_b,
1650 const int8_t* recurrent_to_output_weight_ptr,
1651 int32_t effective_recurrent_to_output_scale_a,
1652 int32_t effective_recurrent_to_output_scale_b,
1653 const int8_t* cell_to_input_weight_ptr,
1654 int32_t effective_cell_to_input_scale_a,
1655 int32_t effective_cell_to_input_scale_b,
1656 const int8_t* cell_to_forget_weight_ptr,
1657 int32_t effective_cell_to_forget_scale_a,
1658 int32_t effective_cell_to_forget_scale_b,
1659 const int8_t* cell_to_output_weight_ptr,
1660 int32_t effective_cell_to_output_scale_a,
1661 int32_t effective_cell_to_output_scale_b,
1662 const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a,
1663 int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr,
1664 int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b,
1665 const int16_t* layer_norm_forget_weight_ptr,
1666 int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b,
1667 const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a,
1668 int32_t layer_norm_cell_scale_b,
1669 const int16_t* layer_norm_output_weight_ptr,
1670 int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b,
1671 const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr,
1672 const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr,
1673 const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params,
1674 const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b,
1675 const int32_t* intermediate_zp, int16_t quantized_cell_clip,
1676 int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input,
1677 int n_output, int output_batch_leading_dim, int8_t* output_state_ptr,
1678 int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr,
1679 int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3,
1680 int16_t* scratch4, int16_t* scratch5, int16_t* scratch6,
1681 int16_t* scratch7) {
1682 // TODO(b/159066113): scratch5 is unused, remove.
1683
1684 ruy::profiler::ScopeLabel label("LstmStepInteger8x8_8");
1685 // Make named scratch buffers for the different gates.
1686 int16_t* forget_gate_scratch = scratch2;
1687 int16_t* cell_gate_scratch = scratch3;
1688 int16_t* output_gate_scratch = scratch4;
1689 // Only CIFG is supported here, so there is no input gate scratch buffer.
1690
1691 // Calculate the forget gate.
1692 CalculateLstmGateInteger8x8_8(
1693 input_ptr, input_zp, input_to_forget_weight_ptr,
1694 effective_input_to_forget_scale_a, effective_input_to_forget_scale_b,
1695 intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4],
1696 output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr,
1697 effective_recurrent_to_forget_scale_a,
1698 effective_recurrent_to_forget_scale_b, intermediate_scale_a[3],
1699 intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr,
1700 layer_norm_forget_scale_a, layer_norm_forget_scale_b,
1701 forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
1702 kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1);
1703 // Calculate the cell update gate.
1704 CalculateLstmGateInteger8x8_8(
1705 input_ptr, input_zp, input_to_cell_weight_ptr,
1706 effective_input_to_cell_scale_a, effective_input_to_cell_scale_b,
1707 intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7],
1708 output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr,
1709 effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b,
1710 intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8],
1711 layer_norm_cell_weight_ptr, layer_norm_cell_scale_a,
1712 layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output,
1713 n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1);
1714 // Update the cell state.
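// This kernel fixes cell_state_scale at -15, i.e. the cell state is kept in
// Q0.15: int16 values representing reals in roughly [-1, 1).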
1715 UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr,
1716 /*cell_state_scale=*/-15, /*input_gate=*/nullptr,
1717 forget_gate_scratch, cell_gate_scratch,
1718 /*use_cifg=*/true, quantized_cell_clip);
1719 // Calculate the output gate.
1720 CalculateLstmGateInteger8x8_8(
1721 input_ptr, input_zp, input_to_output_weight_ptr,
1722 effective_input_to_output_scale_a, effective_input_to_output_scale_b,
1723 intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10],
1724 output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr,
1725 effective_recurrent_to_output_scale_a,
1726 effective_recurrent_to_output_scale_b, intermediate_scale_a[7],
1727 intermediate_scale_b[7], intermediate_zp[11], layer_norm_output_weight_ptr,
1728 layer_norm_output_scale_a, layer_norm_output_scale_b,
1729 output_gate_bias_ptr, n_batch, n_input, n_output, n_cell,
1730 kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1);
1731 // Update the output state.
1732 CalculateLstmOutputInteger8x8_8(
1733 n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch,
1734 projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b,
1735 projection_bias_ptr, output_state_zp, quantized_proj_clip,
1736 output_state_ptr, scratch2);
1737 // Copy output state to the output. Note that unlike float or hybrid, output
1738 // is always contiguous.
1739 std::copy_n(output_state_ptr, n_batch * n_output, output_ptr);
1740}
1741
1742} // namespace
1743
1744// LINT.IfChange
1745TfLiteStatus EvalFloat(
1746 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1747 const TfLiteTensor* input_to_forget_weights,
1748 const TfLiteTensor* input_to_cell_weights,
1749 const TfLiteTensor* input_to_output_weights,
1750 const TfLiteTensor* recurrent_to_input_weights,
1751 const TfLiteTensor* recurrent_to_forget_weights,
1752 const TfLiteTensor* recurrent_to_cell_weights,
1753 const TfLiteTensor* recurrent_to_output_weights,
1754 const TfLiteTensor* cell_to_input_weights,
1755 const TfLiteTensor* cell_to_forget_weights,
1756 const TfLiteTensor* cell_to_output_weights,
1757 const TfLiteTensor* input_layer_norm_coefficients,
1758 const TfLiteTensor* forget_layer_norm_coefficients,
1759 const TfLiteTensor* cell_layer_norm_coefficients,
1760 const TfLiteTensor* output_layer_norm_coefficients,
1761 const TfLiteTensor* aux_input,
1762 const TfLiteTensor* aux_input_to_input_weights,
1763 const TfLiteTensor* aux_input_to_forget_weights,
1764 const TfLiteTensor* aux_input_to_cell_weights,
1765 const TfLiteTensor* aux_input_to_output_weights,
1766 const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1767 const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
1768 const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
1769 const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
1770 int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state,
1771 TfLiteTensor* cell_state, TfLiteTensor* output,
1772 CpuBackendContext* context) {
1773 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1774 int max_time, n_batch;
1775 if (input->dims->size == 3) {
1776 max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1777 n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1778 } else {
1779 max_time = 1;
1780 n_batch = input->dims->data[0];
1781 }
1782 const int n_input = input->dims->data[input->dims->size - 1];
1783 const int aux_input_size =
1784 (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1785
1786 // n_cell and n_output will be the same size when there is no projection.
1787 const int n_cell = input_to_output_weights->dims->data[0];
1788 const int n_output = recurrent_to_output_weights->dims->data[1];
1789
1790 // Since we have already checked that weights are all there or none, we can
1791 // check the existence of only one to get the condition.
1792 const bool use_cifg = (input_to_input_weights == nullptr);
1793
1794 // Index the scratch buffer pointers into the global scratch buffer.
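// Layout of scratch_buffer, in blocks of n_cell * n_batch floats:
//   without CIFG: [input gate | cell gate | forget gate | output gate | accumulation]
//   with CIFG:    [cell gate | forget gate | output gate | accumulation]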
1795 float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
1796 float* input_gate_scratch = nullptr;
1797 float* cell_gate_scratch = nullptr;
1798 float* forget_gate_scratch = nullptr;
1799 float* output_gate_scratch = nullptr;
1800 float* accumulation_scratch_buffer = nullptr;
1801 if (use_cifg) {
1802 cell_gate_scratch = scratch_buffer_ptr;
1803 forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1804 output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1805 accumulation_scratch_buffer = scratch_buffer_ptr + 3 * n_cell * n_batch;
1806 } else {
1807 input_gate_scratch = scratch_buffer_ptr;
1808 cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
1809 forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
1810 output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
1811 accumulation_scratch_buffer = scratch_buffer_ptr + 4 * n_cell * n_batch;
1812 }
1813
1814 const int output_batch_leading_dim =
1815 output->dims->data[output->dims->size - 1];
1816 if (time_major) {
1817 // Loop through the sequence.
1818 const int input_step = n_batch * n_input;
1819 const int output_step = n_batch * output_batch_leading_dim;
1820 for (int t = 0; t < max_time; t++) {
1821 // If this is the forward_sequence, step forward, otherwise step
1822 // backwards.
1823 const int t_rel = forward_sequence ? t : max_time - t - 1;
1824 const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
1825 const float* aux_input_ptr = nullptr;
1826 if (aux_input) {
1827 aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
1828 }
1829 float* output_ptr =
1830 GetTensorData<float>(output) + t_rel * output_step + output_offset;
1831
1832 LstmStepFloat(
1833 input_ptr, GetTensorData<float>(input_to_input_weights),
1834 GetTensorData<float>(input_to_forget_weights),
1835 GetTensorData<float>(input_to_cell_weights),
1836 GetTensorData<float>(input_to_output_weights), aux_input_ptr,
1837 GetTensorData<float>(aux_input_to_input_weights),
1838 GetTensorData<float>(aux_input_to_forget_weights),
1839 GetTensorData<float>(aux_input_to_cell_weights),
1840 GetTensorData<float>(aux_input_to_output_weights),
1841 GetTensorData<float>(recurrent_to_input_weights),
1842 GetTensorData<float>(recurrent_to_forget_weights),
1843 GetTensorData<float>(recurrent_to_cell_weights),
1844 GetTensorData<float>(recurrent_to_output_weights),
1845 GetTensorData<float>(cell_to_input_weights),
1846 GetTensorData<float>(cell_to_forget_weights),
1847 GetTensorData<float>(cell_to_output_weights),
1848 GetTensorData<float>(input_layer_norm_coefficients),
1849 GetTensorData<float>(forget_layer_norm_coefficients),
1850 GetTensorData<float>(cell_layer_norm_coefficients),
1851 GetTensorData<float>(output_layer_norm_coefficients),
1852 GetTensorData<float>(input_gate_bias),
1853 GetTensorData<float>(forget_gate_bias),
1854 GetTensorData<float>(cell_gate_bias),
1855 GetTensorData<float>(output_gate_bias),
1856 GetTensorData<float>(projection_weights),
1857 GetTensorData<float>(projection_bias), params, n_batch, n_cell,
1858 n_input, aux_input_size, n_output, output_batch_leading_dim,
1859 GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
1860 input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
1861 output_gate_scratch, accumulation_scratch_buffer, output_ptr,
1862 context);
1863 }
1864 } else {
1865 for (int b = 0; b < n_batch; b++) {
1866 const int input_step = n_input;
1867 const int output_step = output_batch_leading_dim;
1868 for (int t = 0; t < max_time; t++) {
1869 // If this is the forward_sequence, step forward, otherwise step
1870 // backwards.
1871 const int t_rel = forward_sequence ? t : max_time - t - 1;
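// In batch-major layout the input is [n_batch, max_time, ...], so the data for
// (batch b, time t_rel) starts at offset (b * max_time + t_rel) * input_step.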
1872 const int time_offset = b * max_time + t_rel;
1873 const float* input_ptr =
1874 GetTensorData<float>(input) + time_offset * input_step;
1875 const float* aux_input_ptr = nullptr;
1876 if (aux_input) {
1877 aux_input_ptr =
1878 GetTensorData<float>(aux_input) + time_offset * input_step;
1879 }
1880 float* output_ptr = GetTensorData<float>(output) +
1881 time_offset * output_step + output_offset;
1882
1883 // Offset the {output,cell}_state pointers to the right batch.
1884 float* output_state_ptr =
1885 GetTensorData<float>(output_state) + b * output_batch_leading_dim;
1886 float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
1887 // Offset the scratch pointers to the right batch.
1888 float* input_gate_scratch_ptr =
1889 input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
1890 float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
1891 float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
1892 float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
1893
1894 LstmStepFloat(
1895 input_ptr, GetTensorData<float>(input_to_input_weights),
1896 GetTensorData<float>(input_to_forget_weights),
1897 GetTensorData<float>(input_to_cell_weights),
1898 GetTensorData<float>(input_to_output_weights), aux_input_ptr,
1899 GetTensorData<float>(aux_input_to_input_weights),
1900 GetTensorData<float>(aux_input_to_forget_weights),
1901 GetTensorData<float>(aux_input_to_cell_weights),
1902 GetTensorData<float>(aux_input_to_output_weights),
1903 GetTensorData<float>(recurrent_to_input_weights),
1904 GetTensorData<float>(recurrent_to_forget_weights),
1905 GetTensorData<float>(recurrent_to_cell_weights),
1906 GetTensorData<float>(recurrent_to_output_weights),
1907 GetTensorData<float>(cell_to_input_weights),
1908 GetTensorData<float>(cell_to_forget_weights),
1909 GetTensorData<float>(cell_to_output_weights),
1910 GetTensorData<float>(input_layer_norm_coefficients),
1911 GetTensorData<float>(forget_layer_norm_coefficients),
1912 GetTensorData<float>(cell_layer_norm_coefficients),
1913 GetTensorData<float>(output_layer_norm_coefficients),
1914 GetTensorData<float>(input_gate_bias),
1915 GetTensorData<float>(forget_gate_bias),
1916 GetTensorData<float>(cell_gate_bias),
1917 GetTensorData<float>(output_gate_bias),
1918 GetTensorData<float>(projection_weights),
1919 GetTensorData<float>(projection_bias), params, /*n_batch=*/1,
1920 n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim,
1921 output_state_ptr, cell_state_ptr, input_gate_scratch_ptr,
1922 forget_gate_scratch_ptr, cell_gate_scratch_ptr,
1923 output_gate_scratch_ptr, accumulation_scratch_buffer, output_ptr,
1924 context);
1925 }
1926 }
1927 }
1928 return kTfLiteOk;
1929}
1930// LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc)
1931
1932TfLiteStatus EvalHybrid(
1933 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
1934 const TfLiteTensor* input_to_input_weights_ledger,
1935 const TfLiteTensor* input_to_forget_weights,
1936 const TfLiteTensor* input_to_forget_weights_ledger,
1937 const TfLiteTensor* input_to_cell_weights,
1938 const TfLiteTensor* input_to_cell_weights_ledger,
1939 const TfLiteTensor* input_to_output_weights,
1940 const TfLiteTensor* input_to_output_weights_ledger,
1941 const TfLiteTensor* recurrent_to_input_weights,
1942 const TfLiteTensor* recurrent_to_input_weights_ledger,
1943 const TfLiteTensor* recurrent_to_forget_weights,
1944 const TfLiteTensor* recurrent_to_forget_weights_ledger,
1945 const TfLiteTensor* recurrent_to_cell_weights,
1946 const TfLiteTensor* recurrent_to_cell_weights_ledger,
1947 const TfLiteTensor* recurrent_to_output_weights,
1948 const TfLiteTensor* recurrent_to_output_weights_ledger,
1949 const TfLiteTensor* cell_to_input_weights,
1950 const TfLiteTensor* cell_to_forget_weights,
1951 const TfLiteTensor* cell_to_output_weights,
1952 const TfLiteTensor* input_layer_norm_coefficients,
1953 const TfLiteTensor* forget_layer_norm_coefficients,
1954 const TfLiteTensor* cell_layer_norm_coefficients,
1955 const TfLiteTensor* output_layer_norm_coefficients,
1956 const TfLiteTensor* aux_input,
1957 const TfLiteTensor* aux_input_to_input_weights,
1958 const TfLiteTensor* aux_input_to_forget_weights,
1959 const TfLiteTensor* aux_input_to_cell_weights,
1960 const TfLiteTensor* aux_input_to_output_weights,
1961 const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
1962 const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
1963 const TfLiteTensor* projection_weights,
1964 const TfLiteTensor* projection_weights_ledger,
1965 const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params,
1966 bool forward_sequence, bool time_major, int output_offset,
1967 TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf,
1968 TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf,
1969 TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
1970 TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
1971 TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
1972 TfLiteTensor* output_state, TfLiteTensor* cell_state,
1973 TfLiteTensor* output_scratch_buffer, TfLiteTensor* output,
1974 TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp,
1975 TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size,
1976 bool* compute_row_sums, CpuBackendContext* context) {
1977 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
1978 const int n_input = input->dims->data[input->dims->size - 1];
1979 int max_time, n_batch;
1980 if (input->dims->size == 2) {
1981 max_time = 1;
1982 n_batch = input->dims->data[0];
1983 } else {
1984 max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
1985 n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
1986 }
1987 const int aux_input_size =
1988 (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
1989 // n_cell and n_output will be the same size when there is no projection.
1990 const int n_cell = input_to_output_weights->dims->data[0];
1991 const int n_output = recurrent_to_output_weights->dims->data[1];
1992
1993 // Since we have already checked that weights are all there or none, we can
1994 // check the existence of only one to get the condition.
1995 const bool use_cifg = (input_to_input_weights == nullptr);
1996
1997 float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer);
1998 float* input_gate_scratch = nullptr;
1999 float* cell_gate_scratch = nullptr;
2000 float* forget_gate_scratch = nullptr;
2001 float* output_gate_scratch = nullptr;
2002 if (use_cifg) {
2003 cell_gate_scratch = scratch_buffer_ptr;
2004 forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
2005 output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
2006 } else {
2007 input_gate_scratch = scratch_buffer_ptr;
2008 cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch;
2009 forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch;
2010 output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch;
2011 }
2012
2013 const int output_batch_leading_dim =
2014 output->dims->data[output->dims->size - 1];
2015
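// Per-row zero points and cached weight row sums are only needed with asymmetric
// quantization, where the nonzero input offsets must be folded into the matmul
// results via precomputed row sums of the weights.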
2016 int32_t* input_zp_ptr = nullptr;
2017 int32_t* aux_input_zp_ptr = nullptr;
2018 int32_t* output_state_zp_ptr = nullptr;
2019 int32_t* row_sums_ptr = nullptr;
2020 if (params->asymmetric_quantize_inputs) {
2021 input_zp_ptr = GetTensorData<int32_t>(input_zp);
2022 aux_input_zp_ptr = GetTensorData<int32_t>(aux_input_zp);
2023 output_state_zp_ptr = GetTensorData<int32_t>(output_state_zp);
2024 row_sums_ptr = GetTensorData<int32_t>(row_sums);
2025 }
2026
2027 if (time_major) {
2028 // Feed the sequence into the LSTM step-by-step.
2029 const int input_step = n_batch * n_input;
2030 const int output_step = n_batch * output_batch_leading_dim;
2031 for (int t = 0; t < max_time; t++) {
2032 // If this is the forward_sequence, step forward, otherwise step
2033 // backwards.
2034 const int t_rel = forward_sequence ? t : max_time - t - 1;
2035 const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step;
2036 const float* aux_input_ptr = nullptr;
2037 if (aux_input) {
2038 aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step;
2039 }
2040 float* output_ptr =
2041 GetTensorData<float>(output) + t_rel * output_step + output_offset;
2042 LstmStepHybrid(
2043 input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2044 GetTensorData<uint8_t>(input_to_input_weights_ledger),
2045 GetTensorScale(input_to_input_weights),
2046 GetTensorData<int8_t>(input_to_forget_weights),
2047 GetTensorData<uint8_t>(input_to_forget_weights_ledger),
2048 GetTensorScale(input_to_forget_weights),
2049 GetTensorData<int8_t>(input_to_cell_weights),
2050 GetTensorData<uint8_t>(input_to_cell_weights_ledger),
2051 GetTensorScale(input_to_cell_weights),
2052 GetTensorData<int8_t>(input_to_output_weights),
2053 GetTensorData<uint8_t>(input_to_output_weights_ledger),
2054 GetTensorScale(input_to_output_weights), aux_input_ptr,
2055 GetTensorData<int8_t>(aux_input_to_input_weights),
2056 GetTensorScale(aux_input_to_input_weights),
2057 GetTensorData<int8_t>(aux_input_to_forget_weights),
2058 GetTensorScale(aux_input_to_forget_weights),
2059 GetTensorData<int8_t>(aux_input_to_cell_weights),
2060 GetTensorScale(aux_input_to_cell_weights),
2061 GetTensorData<int8_t>(aux_input_to_output_weights),
2062 GetTensorScale(aux_input_to_output_weights),
2063 GetTensorData<int8_t>(recurrent_to_input_weights),
2064 GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
2065 GetTensorScale(recurrent_to_input_weights),
2066 GetTensorData<int8_t>(recurrent_to_forget_weights),
2067 GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
2068 GetTensorScale(recurrent_to_forget_weights),
2069 GetTensorData<int8_t>(recurrent_to_cell_weights),
2070 GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
2071 GetTensorScale(recurrent_to_cell_weights),
2072 GetTensorData<int8_t>(recurrent_to_output_weights),
2073 GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
2074 GetTensorScale(recurrent_to_output_weights),
2075 GetTensorData<int8_t>(cell_to_input_weights),
2076 GetTensorScale(cell_to_input_weights),
2077 GetTensorData<int8_t>(cell_to_forget_weights),
2078 GetTensorScale(cell_to_forget_weights),
2079 GetTensorData<int8_t>(cell_to_output_weights),
2080 GetTensorScale(cell_to_output_weights),
2081 GetTensorData<float>(input_layer_norm_coefficients),
2082 GetTensorData<float>(forget_layer_norm_coefficients),
2083 GetTensorData<float>(cell_layer_norm_coefficients),
2084 GetTensorData<float>(output_layer_norm_coefficients),
2085 GetTensorData<float>(input_gate_bias),
2086 GetTensorData<float>(forget_gate_bias),
2087 GetTensorData<float>(cell_gate_bias),
2088 GetTensorData<float>(output_gate_bias),
2089 GetTensorData<int8_t>(projection_weights),
2090 GetTensorData<uint8_t>(projection_weights_ledger),
2091 GetTensorScale(projection_weights),
2092 GetTensorData<float>(projection_bias), params, n_batch, n_cell,
2093 n_input, aux_input_size, n_output, output_batch_leading_dim,
2094 input_gate_scratch, forget_gate_scratch, cell_gate_scratch,
2095 output_gate_scratch, GetTensorData<float>(input_sf),
2096 GetTensorData<float>(aux_input_sf),
2097 GetTensorData<float>(output_state_sf),
2098 GetTensorData<float>(prod_scaling_factors),
2099 GetTensorData<float>(recovered_cell_weights),
2100 GetTensorData<int8_t>(input_quantized),
2101 GetTensorData<int8_t>(aux_input_quantized),
2102 GetTensorData<int8_t>(output_state_quantized),
2103 GetTensorData<int8_t>(cell_state_quantized),
2104 GetTensorData<float>(output_state), GetTensorData<float>(cell_state),
2105 GetTensorData<int32_t>(output_scratch_buffer), output_ptr,
2106 input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, row_sums_ptr,
2107 row_sums_size, compute_row_sums, params->asymmetric_quantize_inputs,
2108 context);
2109 }
2110 } else {
2111 for (int b = 0; b < n_batch; b++) {
2112 const int input_step = n_input;
2113 const int output_step = output_batch_leading_dim;
2114 for (int t = 0; t < max_time; t++) {
2115 // If this is the forward_sequence, step forward, otherwise step
2116 // backwards.
2117 const int t_rel = forward_sequence ? t : max_time - t - 1;
2118 const int time_offset = b * max_time + t_rel;
2119 const float* input_ptr =
2120 GetTensorData<float>(input) + time_offset * input_step;
2121 const float* aux_input_ptr = nullptr;
2122 if (aux_input) {
2123 aux_input_ptr =
2124 GetTensorData<float>(aux_input) + time_offset * input_step;
2125 }
2126 float* output_ptr = GetTensorData<float>(output) +
2127 time_offset * output_step + output_offset;
2128
2129 // Offset the {output,cell}_state pointers to the right batch.
2130 float* output_state_ptr =
2131 GetTensorData<float>(output_state) + b * output_batch_leading_dim;
2132 float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell;
2133 // Offset the scratch pointers to the right batch.
2134 float* input_gate_scratch_ptr =
2135 input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr;
2136 float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell;
2137 float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell;
2138 float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell;
2139
2140 LstmStepHybrid(
2141 input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2142 GetTensorData<uint8_t>(input_to_input_weights_ledger),
2143 GetTensorScale(input_to_input_weights),
2144 GetTensorData<int8_t>(input_to_forget_weights),
2145 GetTensorData<uint8_t>(input_to_forget_weights_ledger),
2146 GetTensorScale(input_to_forget_weights),
2147 GetTensorData<int8_t>(input_to_cell_weights),
2148 GetTensorData<uint8_t>(input_to_cell_weights_ledger),
2149 GetTensorScale(input_to_cell_weights),
2150 GetTensorData<int8_t>(input_to_output_weights),
2151 GetTensorData<uint8_t>(input_to_output_weights_ledger),
2152 GetTensorScale(input_to_output_weights), aux_input_ptr,
2153 GetTensorData<int8_t>(aux_input_to_input_weights),
2154 GetTensorScale(aux_input_to_input_weights),
2155 GetTensorData<int8_t>(aux_input_to_forget_weights),
2156 GetTensorScale(aux_input_to_forget_weights),
2157 GetTensorData<int8_t>(aux_input_to_cell_weights),
2158 GetTensorScale(aux_input_to_cell_weights),
2159 GetTensorData<int8_t>(aux_input_to_output_weights),
2160 GetTensorScale(aux_input_to_output_weights),
2161 GetTensorData<int8_t>(recurrent_to_input_weights),
2162 GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
2163 GetTensorScale(recurrent_to_input_weights),
2164 GetTensorData<int8_t>(recurrent_to_forget_weights),
2165 GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
2166 GetTensorScale(recurrent_to_forget_weights),
2167 GetTensorData<int8_t>(recurrent_to_cell_weights),
2168 GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
2169 GetTensorScale(recurrent_to_cell_weights),
2170 GetTensorData<int8_t>(recurrent_to_output_weights),
2171 GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
2172 GetTensorScale(recurrent_to_output_weights),
2173 GetTensorData<int8_t>(cell_to_input_weights),
2174 GetTensorScale(cell_to_input_weights),
2175 GetTensorData<int8_t>(cell_to_forget_weights),
2176 GetTensorScale(cell_to_forget_weights),
2177 GetTensorData<int8_t>(cell_to_output_weights),
2178 GetTensorScale(cell_to_output_weights),
2179 GetTensorData<float>(input_layer_norm_coefficients),
2180 GetTensorData<float>(forget_layer_norm_coefficients),
2181 GetTensorData<float>(cell_layer_norm_coefficients),
2182 GetTensorData<float>(output_layer_norm_coefficients),
2183 GetTensorData<float>(input_gate_bias),
2184 GetTensorData<float>(forget_gate_bias),
2185 GetTensorData<float>(cell_gate_bias),
2186 GetTensorData<float>(output_gate_bias),
2187 GetTensorData<int8_t>(projection_weights),
2188 GetTensorData<uint8_t>(projection_weights_ledger),
2189 GetTensorScale(projection_weights),
2190 GetTensorData<float>(projection_bias), params,
2191 /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
2192 output_batch_leading_dim, input_gate_scratch_ptr,
2193 forget_gate_scratch_ptr, cell_gate_scratch_ptr,
2194 output_gate_scratch_ptr, GetTensorData<float>(input_sf),
2195 GetTensorData<float>(aux_input_sf),
2196 GetTensorData<float>(output_state_sf),
2197 GetTensorData<float>(prod_scaling_factors),
2198 GetTensorData<float>(recovered_cell_weights),
2199 GetTensorData<int8_t>(input_quantized),
2200 GetTensorData<int8_t>(aux_input_quantized),
2201 GetTensorData<int8_t>(output_state_quantized),
2202 GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
2203 cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
2204 output_ptr, input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr,
2205 row_sums_ptr, row_sums_size, compute_row_sums,
2206 params->asymmetric_quantize_inputs, context);
2207 }
2208 }
2209 }
2210
2211 return kTfLiteOk;
2212}
2213
2214TfLiteStatus EvalInteger8x8_16(
2215 const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
2216 const TfLiteTensor* input_to_forget_weights,
2217 const TfLiteTensor* input_to_cell_weights,
2218 const TfLiteTensor* input_to_output_weights,
2219 const TfLiteTensor* recurrent_to_input_weights,
2220 const TfLiteTensor* recurrent_to_forget_weights,
2221 const TfLiteTensor* recurrent_to_cell_weights,
2222 const TfLiteTensor* recurrent_to_output_weights,
2223 const TfLiteTensor* cell_to_input_weights,
2224 const TfLiteTensor* cell_to_forget_weights,
2225 const TfLiteTensor* cell_to_output_weights,
2226 const TfLiteTensor* input_layer_norm_coefficients,
2227 const TfLiteTensor* forget_layer_norm_coefficients,
2228 const TfLiteTensor* cell_layer_norm_coefficients,
2229 const TfLiteTensor* output_layer_norm_coefficients,
2230 const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
2231 const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
2232 const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
2233 const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
2234 const lstm_eval::IntegerLstmParameter* integer_lstm_param,
2235 TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
2236 TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
2237 TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
2238 CpuBackendContext* context) {
2239 TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
2240 const int n_input = input->dims->data[input->dims->size - 1];
2241 int max_time, n_batch;
2242 if (input->dims->size == 2) {
2243 max_time = 1;
2244 n_batch = input->dims->data[0];
2245 } else {
2246 max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
2247 n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
2248 }
2249
2250 // n_cell and n_output will be the same size when there is no projection.
2251 const int n_cell = input_to_output_weights->dims->data[0];
2252 const int n_output = recurrent_to_output_weights->dims->data[1];
2253
2254 // Activation zero point
2255 int output_state_zp = output_state->params.zero_point;
2256
2257 // Get params for time/batch/sequence.
2258 const int output_batch_leading_dim =
2259 output->dims->data[output->dims->size - 1];
2260
2261 if (time_major) {
2262 const int input_step = n_batch * n_input;
2263 const int output_step = n_batch * output_batch_leading_dim;
2264 for (int t = 0; t < max_time; t++) {
2265 const int t_rel = t;
2266 int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
2267 const int8_t* input_ptr =
2268 GetTensorData<int8_t>(input) + t_rel * input_step;
2269 LstmStepInteger8x8_16(
2270 input_ptr, GetTensorData<int8_t>(input_to_input_weights),
2271 integer_lstm_param->effective_input_to_input_scale_a,
2272 integer_lstm_param->effective_input_to_input_scale_b,
2273 GetTensorData<int8_t>(input_to_forget_weights),
2274 integer_lstm_param->effective_input_to_forget_scale_a,
2275 integer_lstm_param->effective_input_to_forget_scale_b,
2276 GetTensorData<int8_t>(input_to_cell_weights),
2277 integer_lstm_param->effective_input_to_cell_scale_a,
2278 integer_lstm_param->effective_input_to_cell_scale_b,
2279 GetTensorData<int8_t>(input_to_output_weights),
2280 integer_lstm_param->effective_input_to_output_scale_a,
2281 integer_lstm_param->effective_input_to_output_scale_b,
2282 GetTensorData<int8_t>(recurrent_to_input_weights),
2283 integer_lstm_param->effective_recurrent_to_input_scale_a,
2284 integer_lstm_param->effective_recurrent_to_input_scale_b,
2285 GetTensorData<int8_t>(recurrent_to_forget_weights),
2286 integer_lstm_param->effective_recurrent_to_forget_scale_a,
2287 integer_lstm_param->effective_recurrent_to_forget_scale_b,
2288 GetTensorData<int8_t>(recurrent_to_cell_weights),
2289 integer_lstm_param->effective_recurrent_to_cell_scale_a,
2290 integer_lstm_param->effective_recurrent_to_cell_scale_b,
2291 GetTensorData<int8_t>(recurrent_to_output_weights),
2292 integer_lstm_param->effective_recurrent_to_output_scale_a,
2293 integer_lstm_param->effective_recurrent_to_output_scale_b,
2294 GetTensorData<int16_t>(cell_to_input_weights),
2295 integer_lstm_param->effective_cell_to_input_scale_a,
2296 integer_lstm_param->effective_cell_to_input_scale_b,
2297 GetTensorData<int16_t>(cell_to_forget_weights),
2298 integer_lstm_param->effective_cell_to_forget_scale_a,
2299 integer_lstm_param->effective_cell_to_forget_scale_b,
2300 GetTensorData<int16_t>(cell_to_output_weights),
2301 integer_lstm_param->effective_cell_to_output_scale_a,
2302 integer_lstm_param->effective_cell_to_output_scale_b,
2303 GetTensorData<int8_t>(projection_weights),
2304 integer_lstm_param->effective_proj_scale_a,
2305 integer_lstm_param->effective_proj_scale_b,
2306 integer_lstm_param->hidden_zp,
2307 integer_lstm_param->effective_hidden_scale_a,
2308 integer_lstm_param->effective_hidden_scale_b,
2309 GetTensorData<int16_t>(input_layer_norm_coefficients),
2310 integer_lstm_param->layer_norm_input_scale_a,
2311 integer_lstm_param->layer_norm_input_scale_b,
2312 GetTensorData<int16_t>(forget_layer_norm_coefficients),
2313 integer_lstm_param->layer_norm_forget_scale_a,
2314 integer_lstm_param->layer_norm_forget_scale_b,
2315 GetTensorData<int16_t>(cell_layer_norm_coefficients),
2316 integer_lstm_param->layer_norm_cell_scale_a,
2317 integer_lstm_param->layer_norm_cell_scale_b,
2318 GetTensorData<int16_t>(output_layer_norm_coefficients),
2319 integer_lstm_param->layer_norm_output_scale_a,
2320 integer_lstm_param->layer_norm_output_scale_b,
2321 GetTensorData<int32_t>(input_gate_bias),
2322 GetTensorData<int32_t>(forget_gate_bias),
2323 GetTensorData<int32_t>(cell_gate_bias),
2324 GetTensorData<int32_t>(output_gate_bias),
2325 integer_lstm_param->quantized_cell_clip,
2326 integer_lstm_param->quantized_proj_clip,
2327 integer_lstm_param->cell_scale,
2328 integer_lstm_param->input_variance_guard,
2329 integer_lstm_param->forget_variance_guard,
2330 integer_lstm_param->cell_variance_guard,
2331 integer_lstm_param->output_variance_guard,
2332 integer_lstm_param->input_to_forget_effective_bias.get(),
2333 integer_lstm_param->recurrent_to_forget_effective_bias.get(),
2334 integer_lstm_param->input_to_cell_effective_bias.get(),
2335 integer_lstm_param->recurrent_to_cell_effective_bias.get(),
2336 integer_lstm_param->input_to_output_effective_bias.get(),
2337 integer_lstm_param->recurrent_to_output_effective_bias.get(),
2338 integer_lstm_param->input_to_input_effective_bias.get(),
2339 integer_lstm_param->recurrent_to_input_effective_bias.get(),
2340 integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
2341 n_input, n_output, GetTensorData<int8_t>(output_state),
2342 output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
2343 GetTensorData<int16_t>(scratch0), GetTensorData<int16_t>(scratch1),
2344 GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
2345 GetTensorData<int8_t>(scratch4), GetTensorData<int32_t>(scratch5),
2346 context);
2347 }
2348 } else {
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // If this is the forward_sequence, step forward, otherwise step
        // backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const int8_t* input_ptr =
            GetTensorData<int8_t>(input) + time_offset * input_step;
        int8_t* output_ptr =
            GetTensorData<int8_t>(output) + time_offset * output_step;

        // Offset the {output,cell}_state pointers to the right batch.
        int8_t* output_state_ptr =
            GetTensorData<int8_t>(output_state) + b * output_batch_leading_dim;
        int16_t* cell_state_ptr =
            GetTensorData<int16_t>(cell_state) + b * n_cell;

        LstmStepInteger8x8_16(
            input_ptr, GetTensorData<int8_t>(input_to_input_weights),
            integer_lstm_param->effective_input_to_input_scale_a,
            integer_lstm_param->effective_input_to_input_scale_b,
            GetTensorData<int8_t>(input_to_forget_weights),
            integer_lstm_param->effective_input_to_forget_scale_a,
            integer_lstm_param->effective_input_to_forget_scale_b,
            GetTensorData<int8_t>(input_to_cell_weights),
            integer_lstm_param->effective_input_to_cell_scale_a,
            integer_lstm_param->effective_input_to_cell_scale_b,
            GetTensorData<int8_t>(input_to_output_weights),
            integer_lstm_param->effective_input_to_output_scale_a,
            integer_lstm_param->effective_input_to_output_scale_b,
            GetTensorData<int8_t>(recurrent_to_input_weights),
            integer_lstm_param->effective_recurrent_to_input_scale_a,
            integer_lstm_param->effective_recurrent_to_input_scale_b,
            GetTensorData<int8_t>(recurrent_to_forget_weights),
            integer_lstm_param->effective_recurrent_to_forget_scale_a,
            integer_lstm_param->effective_recurrent_to_forget_scale_b,
            GetTensorData<int8_t>(recurrent_to_cell_weights),
            integer_lstm_param->effective_recurrent_to_cell_scale_a,
            integer_lstm_param->effective_recurrent_to_cell_scale_b,
            GetTensorData<int8_t>(recurrent_to_output_weights),
            integer_lstm_param->effective_recurrent_to_output_scale_a,
            integer_lstm_param->effective_recurrent_to_output_scale_b,
            GetTensorData<int16_t>(cell_to_input_weights),
            integer_lstm_param->effective_cell_to_input_scale_a,
            integer_lstm_param->effective_cell_to_input_scale_b,
            GetTensorData<int16_t>(cell_to_forget_weights),
            integer_lstm_param->effective_cell_to_forget_scale_a,
            integer_lstm_param->effective_cell_to_forget_scale_b,
            GetTensorData<int16_t>(cell_to_output_weights),
            integer_lstm_param->effective_cell_to_output_scale_a,
            integer_lstm_param->effective_cell_to_output_scale_b,
            GetTensorData<int8_t>(projection_weights),
            integer_lstm_param->effective_proj_scale_a,
            integer_lstm_param->effective_proj_scale_b,
            integer_lstm_param->hidden_zp,
            integer_lstm_param->effective_hidden_scale_a,
            integer_lstm_param->effective_hidden_scale_b,
            GetTensorData<int16_t>(input_layer_norm_coefficients),
            integer_lstm_param->layer_norm_input_scale_a,
            integer_lstm_param->layer_norm_input_scale_b,
            GetTensorData<int16_t>(forget_layer_norm_coefficients),
            integer_lstm_param->layer_norm_forget_scale_a,
            integer_lstm_param->layer_norm_forget_scale_b,
            GetTensorData<int16_t>(cell_layer_norm_coefficients),
            integer_lstm_param->layer_norm_cell_scale_a,
            integer_lstm_param->layer_norm_cell_scale_b,
            GetTensorData<int16_t>(output_layer_norm_coefficients),
            integer_lstm_param->layer_norm_output_scale_a,
            integer_lstm_param->layer_norm_output_scale_b,
            GetTensorData<int32_t>(input_gate_bias),
            GetTensorData<int32_t>(forget_gate_bias),
            GetTensorData<int32_t>(cell_gate_bias),
            GetTensorData<int32_t>(output_gate_bias),
            integer_lstm_param->quantized_cell_clip,
            integer_lstm_param->quantized_proj_clip,
            integer_lstm_param->cell_scale,
            integer_lstm_param->input_variance_guard,
            integer_lstm_param->forget_variance_guard,
            integer_lstm_param->cell_variance_guard,
            integer_lstm_param->output_variance_guard,
            integer_lstm_param->input_to_forget_effective_bias.get(),
            integer_lstm_param->recurrent_to_forget_effective_bias.get(),
            integer_lstm_param->input_to_cell_effective_bias.get(),
            integer_lstm_param->recurrent_to_cell_effective_bias.get(),
            integer_lstm_param->input_to_output_effective_bias.get(),
            integer_lstm_param->recurrent_to_output_effective_bias.get(),
            integer_lstm_param->input_to_input_effective_bias.get(),
            integer_lstm_param->recurrent_to_input_effective_bias.get(),
            integer_lstm_param->projection_effective_bias.get(), /*n_batch=*/1,
            n_cell, n_input, n_output, output_state_ptr, output_state_zp,
            cell_state_ptr, output_ptr, GetTensorData<int16_t>(scratch0),
            GetTensorData<int16_t>(scratch1), GetTensorData<int16_t>(scratch2),
            GetTensorData<int16_t>(scratch3), GetTensorData<int8_t>(scratch4),
            GetTensorData<int32_t>(scratch5), context);
      }
    }
  }

  return kTfLiteOk;
}

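// Fully integer (8x8_8) LSTM evaluation: int8 input, weights, output state and
// output with an int16 cell state. The sequence is walked time-major, one
// LstmStepInteger8x8_8 call per time step over the whole batch.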
TfLiteStatus EvalInteger8x8_8(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
    TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
    TfLiteTensor* scratch6, TfLiteTensor* scratch7) {
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
  const int n_input = input->dims->data[input->dims->size - 1];
  int max_time, n_batch;
  if (input->dims->size == 2) {
    max_time = 1;
    n_batch = input->dims->data[0];
  } else {
    max_time = input->dims->data[0];
    n_batch = input->dims->data[1];
  }

  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  const int32_t input_zp = input->params.zero_point;
  const int32_t output_state_zp = output_state->params.zero_point;

  // Get params for time/batch/sequence.
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  const int input_step = n_batch * n_input;
  const int output_step = n_batch * output_batch_leading_dim;

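  // Walk the sequence in time-major order; unlike the 8x8_16 path above, this
  // path has no reverse-sequence handling (t_rel is always t).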
  for (int t = 0; t < max_time; t++) {
    const int t_rel = t;
    int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
    // Input can be int8 asymmetric or int16 symmetric.
    const int8_t* input_ptr =
        GetTensorData<int8_t>(input) + t_rel * input_step;
    lstm_eval::LstmStepInteger8x8_8(
        input_ptr, input_zp,

        GetTensorData<int8_t>(input_to_input_weights),
        integer_lstm_param->effective_input_to_input_scale_a,
        integer_lstm_param->effective_input_to_input_scale_b,

        GetTensorData<int8_t>(input_to_forget_weights),
        integer_lstm_param->effective_input_to_forget_scale_a,
        integer_lstm_param->effective_input_to_forget_scale_b,

        GetTensorData<int8_t>(input_to_cell_weights),
        integer_lstm_param->effective_input_to_cell_scale_a,
        integer_lstm_param->effective_input_to_cell_scale_b,

        GetTensorData<int8_t>(input_to_output_weights),
        integer_lstm_param->effective_input_to_output_scale_a,
        integer_lstm_param->effective_input_to_output_scale_b,

        GetTensorData<int8_t>(recurrent_to_input_weights),
        integer_lstm_param->effective_recurrent_to_input_scale_a,
        integer_lstm_param->effective_recurrent_to_input_scale_b,

        GetTensorData<int8_t>(recurrent_to_forget_weights),
        integer_lstm_param->effective_recurrent_to_forget_scale_a,
        integer_lstm_param->effective_recurrent_to_forget_scale_b,

        GetTensorData<int8_t>(recurrent_to_cell_weights),
        integer_lstm_param->effective_recurrent_to_cell_scale_a,
        integer_lstm_param->effective_recurrent_to_cell_scale_b,

        GetTensorData<int8_t>(recurrent_to_output_weights),
        integer_lstm_param->effective_recurrent_to_output_scale_a,
        integer_lstm_param->effective_recurrent_to_output_scale_b,

        GetTensorData<int8_t>(cell_to_input_weights),
        integer_lstm_param->effective_cell_to_input_scale_a,
        integer_lstm_param->effective_cell_to_input_scale_b,

        GetTensorData<int8_t>(cell_to_forget_weights),
        integer_lstm_param->effective_cell_to_forget_scale_a,
        integer_lstm_param->effective_cell_to_forget_scale_b,

        GetTensorData<int8_t>(cell_to_output_weights),
        integer_lstm_param->effective_cell_to_output_scale_a,
        integer_lstm_param->effective_cell_to_output_scale_b,

        GetTensorData<int8_t>(projection_weights),
        integer_lstm_param->effective_proj_scale_a,
        integer_lstm_param->effective_proj_scale_b,

        GetTensorData<int16_t>(input_layer_norm_coefficients),
        integer_lstm_param->layer_norm_input_scale_a,
        integer_lstm_param->layer_norm_input_scale_b,

        GetTensorData<int16_t>(forget_layer_norm_coefficients),
        integer_lstm_param->layer_norm_forget_scale_a,
        integer_lstm_param->layer_norm_forget_scale_b,

        GetTensorData<int16_t>(cell_layer_norm_coefficients),
        integer_lstm_param->layer_norm_cell_scale_a,
        integer_lstm_param->layer_norm_cell_scale_b,

        GetTensorData<int16_t>(output_layer_norm_coefficients),
        integer_lstm_param->layer_norm_output_scale_a,
        integer_lstm_param->layer_norm_output_scale_b,

        GetTensorData<int32_t>(input_gate_bias),
        GetTensorData<int32_t>(forget_gate_bias),
        GetTensorData<int32_t>(cell_gate_bias),
        GetTensorData<int32_t>(output_gate_bias),
        GetTensorData<int32_t>(projection_bias),

        params, integer_lstm_param->intermediate_scale_a,
        integer_lstm_param->intermediate_scale_b,
        integer_lstm_param->intermediate_zp,
        integer_lstm_param->quantized_cell_clip,
        integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
        n_output, output_batch_leading_dim, GetTensorData<int8_t>(output_state),
        output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
        GetTensorData<int8_t>(scratch0), GetTensorData<int8_t>(scratch1),
        GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
        GetTensorData<int16_t>(scratch4), GetTensorData<int16_t>(scratch5),
        GetTensorData<int16_t>(scratch6), GetTensorData<int16_t>(scratch7));
  }

  return kTfLiteOk;
}

}  // namespace lstm_eval
}  // namespace builtin
}  // namespace ops
}  // namespace tflite