1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/lstm_eval.h" |
16 | |
17 | #include <math.h> |
18 | #include <string.h> |
19 | |
20 | #include <algorithm> |
21 | #include <cstdint> |
22 | #include <memory> |
23 | #include <vector> |
24 | |
25 | #include "ruy/matrix.h" // from @ruy |
26 | #include "ruy/mul_params.h" // from @ruy |
27 | #include "ruy/profiler/instrumentation.h" // from @ruy |
28 | #include "ruy/ruy.h" // from @ruy |
29 | #include "tensorflow/lite/c/builtin_op_data.h" |
30 | #include "tensorflow/lite/c/common.h" |
31 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
32 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
33 | #include "tensorflow/lite/kernels/internal/kernel_utils.h" |
34 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
35 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
36 | #include "tensorflow/lite/kernels/internal/tensor_utils.h" |
37 | #include "tensorflow/lite/kernels/op_macros.h" |
38 | |
39 | namespace tflite { |
40 | namespace ops { |
41 | namespace builtin { |
42 | namespace lstm_eval { |
43 | namespace { |
44 | |
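// Computes output = matrix * vector + result for a batch of vectors, using the
// optimized FullyConnected kernel: matrix is m_rows x m_cols, each of the
// n_batch vectors holds m_cols entries, and result/output each hold
// n_batch * m_rows values. For n_batch == 1 the addend is passed to
// FullyConnected as its bias input; otherwise it is added in a separate loop
// after the matmul.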
45 | void MatrixBatchVectorMultiplyAccumulate( |
46 | const float* matrix, const float* vector, const float* result, |
47 | float* output, int m_rows, int m_cols, int n_batch, |
48 | CpuBackendContext* cpu_backend_context) { |
49 | tflite::FullyConnectedParams float_fc_params; |
50 | float_fc_params.float_activation_min = std::numeric_limits<float>::lowest(); |
51 | float_fc_params.float_activation_max = std::numeric_limits<float>::max(); |
52 | float_fc_params.lhs_cacheable = true; |
53 | float_fc_params.rhs_cacheable = false; |
54 | |
55 | tflite::RuntimeShape weight_shape({m_rows, m_cols}); |
56 | tflite::RuntimeShape input_shape({n_batch, m_cols}); |
57 | tflite::RuntimeShape output_shape({n_batch, m_rows}); |
58 | if (n_batch == 1) { |
59 | tflite::optimized_ops::FullyConnected( |
60 | float_fc_params, input_shape, vector, weight_shape, matrix, |
61 | output_shape, result, output_shape, output, cpu_backend_context); |
62 | } else { |
63 | tflite::optimized_ops::FullyConnected( |
64 | float_fc_params, input_shape, vector, weight_shape, matrix, |
65 | output_shape, nullptr, output_shape, output, cpu_backend_context); |
66 | for (int i = 0; i < m_rows * n_batch; ++i) { |
67 | output[i] += result[i]; |
68 | } |
69 | } |
70 | } |
71 | |
72 | void ComputeRowSums( |
73 | int32_t* input_to_input_row_sums, int32_t* input_to_forget_row_sums, |
74 | int32_t* input_to_cell_row_sums, int32_t* input_to_output_row_sums, |
75 | int32_t* aux_input_to_input_row_sums, int32_t* aux_input_to_forget_row_sums, |
76 | int32_t* aux_input_to_cell_row_sums, int32_t* aux_input_to_output_row_sums, |
77 | int32_t* recurrent_to_input_row_sums, int32_t* recurrent_to_forget_row_sums, |
78 | int32_t* recurrent_to_cell_row_sums, int32_t* recurrent_to_output_row_sums, |
79 | int32_t* projection_weights_row_sums, int32_t* row_sums, int n_cell, |
80 | int n_input, int n_aux_input, int n_output, |
81 | const int8_t* input_to_input_weights_ptr, |
82 | const int8_t* input_to_forget_weights_ptr, |
83 | const int8_t* input_to_cell_weights_ptr, |
84 | const int8_t* input_to_output_weights_ptr, |
85 | const int8_t* aux_input_to_input_weights_ptr, |
86 | const int8_t* aux_input_to_forget_weights_ptr, |
87 | const int8_t* aux_input_to_cell_weights_ptr, |
88 | const int8_t* aux_input_to_output_weights_ptr, |
89 | const int8_t* recurrent_to_input_weights_ptr, |
90 | const int8_t* recurrent_to_forget_weights_ptr, |
91 | const int8_t* recurrent_to_cell_weights_ptr, |
92 | const int8_t* recurrent_to_output_weights_ptr, |
93 | const int8_t* projection_weights_ptr, bool use_cifg, |
94 | const float* aux_input_ptr) { |
95 | // Compute the row sums for dequantization |
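// Illustrative sketch of how the sums are used: for an asymmetrically
// quantized input with zero point input_zp, the integer accumulator of row r
// is corrected as
//   float_accum ~ weight_scale * input_scale *
//                 (int_accum[r] - input_zp * row_sums[r])
// because sum_c w[r][c] * (q[c] - input_zp)
//       = sum_c w[r][c] * q[c] - input_zp * row_sums[r].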
96 | if (!use_cifg) { |
97 | tensor_utils::ReductionSumVector(input_to_input_weights_ptr, |
98 | input_to_input_row_sums, n_cell, n_input); |
99 | } |
100 | tensor_utils::ReductionSumVector(input_to_forget_weights_ptr, |
101 | input_to_forget_row_sums, n_cell, n_input); |
102 | tensor_utils::ReductionSumVector(input_to_cell_weights_ptr, |
103 | input_to_cell_row_sums, n_cell, n_input); |
104 | tensor_utils::ReductionSumVector(input_to_output_weights_ptr, |
105 | input_to_output_row_sums, n_cell, n_input); |
106 | |
107 | if (aux_input_ptr) { |
108 | if (!use_cifg) { |
109 | tensor_utils::ReductionSumVector(aux_input_to_input_weights_ptr, |
110 | aux_input_to_input_row_sums, n_cell, |
111 | n_aux_input); |
112 | } |
113 | tensor_utils::ReductionSumVector(aux_input_to_forget_weights_ptr, |
114 | aux_input_to_forget_row_sums, n_cell, |
115 | n_aux_input); |
116 | tensor_utils::ReductionSumVector(aux_input_to_cell_weights_ptr, |
117 | aux_input_to_cell_row_sums, n_cell, |
118 | n_aux_input); |
119 | tensor_utils::ReductionSumVector(aux_input_to_output_weights_ptr, |
120 | aux_input_to_output_row_sums, n_cell, |
121 | n_aux_input); |
122 | } |
123 | if (!use_cifg) { |
124 | tensor_utils::ReductionSumVector(recurrent_to_input_weights_ptr, |
125 | recurrent_to_input_row_sums, n_cell, |
126 | n_output); |
127 | } |
128 | tensor_utils::ReductionSumVector(recurrent_to_forget_weights_ptr, |
129 | recurrent_to_forget_row_sums, n_cell, |
130 | n_output); |
131 | tensor_utils::ReductionSumVector(recurrent_to_cell_weights_ptr, |
132 | recurrent_to_cell_row_sums, n_cell, |
133 | n_output); |
134 | tensor_utils::ReductionSumVector(recurrent_to_output_weights_ptr, |
135 | recurrent_to_output_row_sums, n_cell, |
136 | n_output); |
137 | |
138 | if (projection_weights_ptr != nullptr) { |
139 | tensor_utils::ReductionSumVector( |
140 | projection_weights_ptr, projection_weights_row_sums, n_output, n_cell); |
141 | } |
142 | } |
143 | |
144 | inline float GetTensorScale(const TfLiteTensor* tensor) { |
145 | return tensor == nullptr ? 1.0f : tensor->params.scale; |
146 | } |
147 | |
148 | // LINT.IfChange |
149 | // Calculates a single LSTM gate. |
150 | // |
151 | // Implements the following formula: (* is matrix multiply) |
152 | // gate = activate(W_input * input + W_aux * aux_input + |
153 | // W_peephole * cell + W_recurrent * prev_output + bias) |
// with layer norm:
//   gate = activate(W_norm * normalize(...) + bias)
// (with layer norm, the bias is added after normalization rather than inside
// the normalized sum)
156 | // |
157 | // Activation is sigmoid except for the "cell" gate (configurable, usually tanh) |
158 | // |
159 | // Parameters: |
160 | // Input vectors (to LSTM): | Size: | Optional? |
161 | // input | n_input | |
162 | // aux_input | n_aux_input | y (bidir LSTM) |
163 | // Input vectors (persistent states): |
164 | // output_state | n_output | |
165 | // cell_state | n_cell | |
166 | // 'Constant' inputs: |
167 | // input_to_gate_weights | n_cell * n_input | |
168 | // aux_input_to_gate_weights | n_cell * n_aux_input | y (bidir LSTM) |
169 | // recurrent_to_gate_weights | n_cell * n_output | |
170 | // cell_to_gate_weights | n_cell | y (peephole) |
171 | // gate_bias | n_cell | |
172 | // layer_norm_coefficients | n_cell | y (layer norm) |
173 | // Output vector: |
174 | // gate | n_cell | |
175 | // Scalar parameters: |
176 | // n_batch - batch size / number of vectors |
177 | // n_input, n_aux_input, n_output, n_cell - size of vectors. |
178 | // activation - activation to use. |
179 | // is_input_all_zeros, is_aux_input_all_zeros - if input vectors are all zero. |
180 | // use_layer_norm - if doing layer norm LSTM. |
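//
// Illustrative sketch (not part of the kernel): ignoring layer norm, for batch
// b and cell c this computes
//   gate[b][c] = activate(gate_bias[c]
//                + sum_i input_to_gate_weights[c][i]     * input[b][i]
//                + sum_j aux_input_to_gate_weights[c][j] * aux_input[b][j]
//                + sum_k recurrent_to_gate_weights[c][k] * output_state[b][k]
//                + cell_to_gate_weights[c] * cell_state[b][c])
// where the aux-input and peephole terms are present only when those weights
// are given.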
181 | inline void CalculateLstmGateFloat( |
182 | const float* input, const float* input_to_gate_weights, |
183 | const float* aux_input, const float* aux_input_to_gate_weights, |
184 | const float* output_state, const float* recurrent_to_gate_weights, |
185 | const float* cell_state, const float* cell_to_gate_weights, |
186 | const float* layer_norm_coefficients, const float* gate_bias, |
187 | const int n_batch, const int n_input, const int n_aux_input, |
188 | const int n_output, const int n_cell, |
189 | const TfLiteFusedActivation activation, float* gate, |
190 | const bool is_input_all_zeros, const bool is_aux_input_all_zeros, |
191 | float* output, CpuBackendContext* context) { |
192 | const bool use_peephole = (cell_to_gate_weights != nullptr); |
193 | const bool use_layer_norm = (layer_norm_coefficients != nullptr); |
194 | |
195 | // Initialize scratch buffers with bias for regular lstm or initialize with |
196 | // zero for layer norm lstm. |
197 | if (use_layer_norm) { |
198 | std::fill_n(gate, n_cell * n_batch, 0.0f); |
199 | } else { |
200 | tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate); |
201 | } |
202 | // For each batch and cell: compute input_weight * input. |
203 | // Skip if input is all zeros. |
204 | float* accumulation_buffer = gate; |
205 | if (!is_input_all_zeros) { |
206 | MatrixBatchVectorMultiplyAccumulate(input_to_gate_weights, input, |
207 | accumulation_buffer, output, n_cell, |
208 | n_input, n_batch, context); |
209 | std::swap(accumulation_buffer, output); |
210 | } |
211 | // For each batch and cell: compute aux_input_weight * aux_input. |
212 | // Skip if auxiliary input is not available or all zeros. |
213 | if (!is_aux_input_all_zeros) { |
214 | MatrixBatchVectorMultiplyAccumulate(aux_input_to_gate_weights, aux_input, |
215 | accumulation_buffer, output, n_cell, |
216 | n_aux_input, n_batch, context); |
217 | std::swap(accumulation_buffer, output); |
218 | } |
219 | // For each batch and cell: compute recurrent_weight * output_state. |
220 | MatrixBatchVectorMultiplyAccumulate(recurrent_to_gate_weights, output_state, |
221 | accumulation_buffer, output, n_cell, |
222 | n_output, n_batch, context); |
223 | // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM) |
224 | if (use_peephole) { |
225 | tensor_utils::VectorBatchVectorCwiseProductAccumulate( |
226 | cell_to_gate_weights, n_cell, cell_state, n_batch, output); |
227 | } |
228 | // Do layer normalization (if layer norm LSTM) |
229 | if (use_layer_norm) { |
230 | tensor_utils::MeanStddevNormalization(output, output, n_cell, n_batch); |
231 | tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, |
232 | output, n_batch, output); |
233 | tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, output); |
234 | } |
235 | // Apply activation |
236 | tensor_utils::ApplyActivationToVector(output, n_batch * n_cell, activation, |
237 | gate); |
238 | } |
239 | |
240 | // Updates the LSTM cell state, used by both float and hybrid LSTM versions. |
241 | // |
242 | // Implements the following formula: |
243 | // cell_state_new = clip(forget_gate * cell_state + input_gate * cell_gate) |
244 | // |
245 | // With CIFG LSTM, input gate is replaced by (1-forget_gate). |
246 | // |
247 | // Parameters: |
248 | // - n_batch, n_cell: sizes of vectors |
249 | // - cell_state: input/output vector, size n_batch*n_cell |
250 | // - input_gate: input vector, size n_batch*n_cell. |
251 | // - forget_gate: input/scratch vector, size n_batch*n_cell, modified with CIFG |
252 | // - cell_gate: input vector, size n_batch*n_cell. |
253 | // - use_cifg: use 1-forget_gate instead of input_gate. |
254 | // - clip: if > 0, clip the resulting cell state to [-clip, +clip]. |
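//
// Illustrative per-element form (b = batch index, c = cell index):
//   cell_state[b][c] = clip(forget_gate[b][c] * cell_state[b][c] +
//                           input_gate[b][c]  * cell_gate[b][c])
// where, with CIFG, input_gate[b][c] is replaced by 1 - forget_gate[b][c].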
255 | void UpdateLstmCellFloat(int n_batch, int n_cell, float* cell_state, |
256 | const float* input_gate, float* forget_gate, |
257 | const float* cell_gate, bool use_cifg, float clip) { |
258 | tensor_utils::VectorVectorCwiseProduct(forget_gate, cell_state, |
259 | n_batch * n_cell, cell_state); |
260 | |
261 | if (use_cifg) { |
262 | // With CIFG, input_gate = 1-forget_gate. Use the forget_gate array as |
263 | // scratch, as input_gate array is not allocated in this case. (Be careful |
264 | // not to write to the scratch before reading the forget gate data.) |
265 | float* scratch = forget_gate; |
266 | tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch); |
267 | tensor_utils::VectorVectorCwiseProductAccumulate( |
268 | cell_gate, scratch, n_batch * n_cell, cell_state); |
269 | } else { |
270 | tensor_utils::VectorVectorCwiseProductAccumulate( |
271 | cell_gate, input_gate, n_batch * n_cell, cell_state); |
272 | } |
273 | if (clip > 0.0f) { |
274 | tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip); |
275 | } |
276 | } |
277 | |
278 | // Calculates the output state tensor of an LSTM step. |
279 | // |
280 | // Implements the following formula: |
281 | // output_no_projection = output_gate .* activate(cell_state) |
282 | // (elementwise vector product) |
283 | // If no projection is used: |
284 | // output = output_state = output_no_projection |
285 | // With projection: |
286 | // output = output_state = clip(W*output_no_projection + bias) |
287 | // |
// The output tensor may have a 'stride' (output_batch_leading_dim) different
// from n_output, so the caller copies output_state into the output tensor one
// batch row at a time.
289 | // |
290 | // Parameters: |
291 | // - n_batch: batches: the number of distinct vectors in each array. |
292 | // - n_cell, n_output: sizes of vectors. |
293 | // - cell_state, output_gate: input vectors, size n_batch*n_cell. |
294 | // - projection_weights, projection_weights_scale, projection_bias: |
295 | // constant inputs, describing projection matrix and bias. |
296 | // - proj_clip: if > 0, clip the output of the projection. |
// - output_state: output vector, size n_batch*n_output. Must be contiguous.
298 | // - scratch: scratch area to store output_no_projection. Size n_batch*n_cell. |
299 | // - projection_bias_scratch: scratch area to store projection_bias. Size |
300 | // n_batch*n_cell. |
301 | // - context: the CpuBackendContext for use with matrix multiplications. |
302 | void CalculateLstmOutputFloat(int n_batch, int n_cell, int n_output, |
303 | const float* cell_state, const float* output_gate, |
304 | TfLiteFusedActivation activation, |
305 | const float* projection_weights, |
306 | const float* projection_bias, |
307 | const float proj_clip, float* output_state, |
308 | float* scratch, float* projection_bias_scratch, |
309 | CpuBackendContext* context) { |
310 | tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, |
311 | activation, scratch); |
312 | tensor_utils::VectorVectorCwiseProduct(output_gate, scratch, n_batch * n_cell, |
313 | scratch); |
314 | |
315 | const bool use_projection = (projection_weights != nullptr); |
316 | const bool use_projection_bias = (projection_bias != nullptr); |
317 | |
318 | if (use_projection) { |
319 | if (use_projection_bias) { |
320 | tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch, |
321 | projection_bias_scratch); |
322 | } else { |
323 | std::fill_n(projection_bias_scratch, n_batch * n_output, 0.0f); |
324 | } |
325 | MatrixBatchVectorMultiplyAccumulate(projection_weights, scratch, |
326 | projection_bias_scratch, output_state, |
327 | n_output, n_cell, n_batch, context); |
328 | if (proj_clip > 0.0f) { |
329 | tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip); |
330 | } |
331 | } else { |
332 | std::copy_n(scratch, n_batch * n_output, output_state); |
333 | } |
334 | } |
335 | // LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\ |
336 | // ../experimental/kernels/fp16/lstm_eval.cc) |
337 | |
338 | // Calculates a single LSTM gate, hybrid version. |
339 | // Implements the same functionality as CalculateLstmGateFloat. |
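//
// The input, the auxiliary input and the output state arrive already quantized
// to int8 with per-batch scaling factors (*_sf) and zero points (*_zp); weights
// are int8 with a single scale. As a rough sketch, an integer accumulator is
// turned back into float as
//   float_accum ~ weights_scale * input_sf[b] * int_accum
// with an additional row_sums * zero-point correction when
// asymmetric_quantize_inputs is enabled.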
340 | void CalculateLstmGateHybrid( |
341 | // Input and weights |
342 | const int8_t* input, const float* input_sf, const int32_t* input_zp, |
343 | const int8_t* input_to_gate_weights, |
344 | const uint8_t* input_to_gate_weights_ledger, |
345 | const float input_to_gate_weights_scale, int32_t* input_to_gate_row_sums, |
346 | // Aux input and weights |
347 | const int8_t* aux_input, const float* aux_input_sf, |
348 | const int32_t* aux_input_zp, const int8_t* aux_input_to_gate_weights, |
349 | const float aux_input_to_gate_weights_scale, |
350 | int32_t* aux_input_to_gate_row_sums, |
351 | // Output state and weights |
352 | const int8_t* output_state, const float* output_state_sf, |
353 | const int32_t* output_state_zp, const int8_t* recurrent_to_gate_weights, |
354 | const uint8_t* recurrent_to_gate_weights_ledger, |
355 | const float recurrent_to_gate_weights_scale, |
356 | int32_t* recurrent_to_gate_row_sums, |
357 | // Cell state and weights (peephole LSTM) |
358 | const float* cell_state, const int8_t* cell_to_gate_weights, |
359 | const float cell_to_gate_weights_scale, |
360 | // Layer normalization coefficients (layer norm LSTM) + gate bias |
361 | const float* layer_norm_coefficients, const float* gate_bias, |
362 | // Array sizes |
363 | const int n_batch, const int n_input, const int n_aux_input, |
364 | const int n_output, const int n_cell, |
365 | const TfLiteFusedActivation activation, |
366 | // Output |
367 | float* gate, |
368 | // Parameters for performance optimizations |
369 | const bool is_input_all_zeros, const bool is_aux_input_all_zeros, |
370 | const bool is_output_state_all_zeros, bool* compute_row_sums, |
371 | CpuBackendContext* context, |
372 | // Scratch arrays |
373 | float* scratch0, // size: n_batch |
374 | float* scratch1, // size: n_cell, only used if peephole LSTM |
375 | int32_t* accum_scratch // For MatrixBatchVectorMultiplyAccumulate |
376 | ) { |
377 | const bool use_peephole = (cell_to_gate_weights != nullptr); |
378 | const bool use_layer_norm = (layer_norm_coefficients != nullptr); |
379 | |
380 | // Initialize scratch buffers with bias for regular lstm or initialize with |
381 | // zero for layer norm lstm. |
382 | if (use_layer_norm) { |
383 | std::fill_n(gate, n_cell * n_batch, 0.0f); |
384 | } else { |
385 | tensor_utils::VectorBatchVectorAssign(gate_bias, n_cell, n_batch, gate); |
386 | } |
387 | // For each batch and cell: compute input_weight * input. |
388 | // Skip if input is all zeros. |
389 | if (!is_input_all_zeros) { |
390 | if (input_to_gate_weights_ledger != nullptr) { |
391 | std::vector<float> scales(n_batch); |
392 | for (int i = 0; i < n_batch; i++) { |
393 | scales[i] = input_to_gate_weights_scale * input_sf[i]; |
394 | } |
395 | tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( |
396 | input_to_gate_weights, input_to_gate_weights_ledger, n_cell, n_input, |
397 | input, scales.data(), n_batch, gate); |
398 | |
399 | } else { |
400 | tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
401 | input_to_gate_weights, n_cell, n_input, input, |
402 | input_to_gate_weights_scale, input_sf, n_batch, gate, |
403 | /*per_channel_scale=*/nullptr, input_zp, accum_scratch, |
404 | input_to_gate_row_sums, compute_row_sums, scratch0, context); |
405 | } |
406 | } |
407 | // For each batch and cell: compute aux_input_weight * aux_input. |
408 | // Skip if auxiliary input is not available or all zeros. |
409 | if (!is_aux_input_all_zeros) { |
410 | tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
411 | aux_input_to_gate_weights, n_cell, n_aux_input, aux_input, |
412 | aux_input_to_gate_weights_scale, aux_input_sf, n_batch, gate, |
413 | /*per_channel_scale=*/nullptr, aux_input_zp, accum_scratch, |
414 | aux_input_to_gate_row_sums, compute_row_sums, scratch0, context); |
415 | } |
416 | // For each batch and cell: compute recurrent_weight * output_state. |
417 | // Skip if output state is all zeros. |
418 | if (!is_output_state_all_zeros) { |
419 | if (recurrent_to_gate_weights_ledger != nullptr) { |
std::vector<float> scales(n_batch);
for (int i = 0; i < n_batch; i++) {
  // Use the output_state scaling factors here: the quantized output state,
  // not the input, is the vector operand of this matmul.
  scales[i] = recurrent_to_gate_weights_scale * output_state_sf[i];
}
424 | tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( |
425 | recurrent_to_gate_weights, recurrent_to_gate_weights_ledger, n_cell, |
426 | n_output, output_state, scales.data(), n_batch, gate); |
427 | } else { |
428 | tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
429 | recurrent_to_gate_weights, n_cell, n_output, output_state, |
430 | recurrent_to_gate_weights_scale, output_state_sf, n_batch, gate, |
431 | /*per_channel_scale=*/nullptr, output_state_zp, accum_scratch, |
432 | recurrent_to_gate_row_sums, compute_row_sums, scratch0, context); |
433 | } |
434 | } |
435 | // For each batch and cell: compute cell_weight .* cell_state (peephole LSTM) |
436 | if (use_peephole) { |
437 | float* recovered_cell_weights = scratch1; |
438 | tensor_utils::VectorScalarMultiply(cell_to_gate_weights, n_cell, |
439 | cell_to_gate_weights_scale, |
440 | recovered_cell_weights); |
441 | tensor_utils::VectorBatchVectorCwiseProductAccumulate( |
442 | recovered_cell_weights, n_cell, cell_state, n_batch, gate); |
443 | } |
444 | // Do layer normalization (if layer norm LSTM) |
445 | if (use_layer_norm) { |
446 | tensor_utils::MeanStddevNormalization(gate, gate, n_cell, n_batch); |
447 | tensor_utils::VectorBatchVectorCwiseProduct(layer_norm_coefficients, n_cell, |
448 | gate, n_batch, gate); |
449 | tensor_utils::VectorBatchVectorAdd(gate_bias, n_cell, n_batch, gate); |
450 | } |
451 | // Apply activation |
452 | tensor_utils::ApplyActivationToVector(gate, n_cell * n_batch, activation, |
453 | gate); |
454 | } |
455 | |
456 | // Calculates the output state tensor of an LSTM step. See Float version too. |
457 | // |
458 | // Parameters: |
459 | // - n_batch: batches: the number of distinct vectors in each array. |
460 | // - n_cell, n_output: sizes of vectors. |
461 | // - cell_state, output_gate: input vectors, size n_batch*n_cell. |
462 | // - projection_weights, projection_weights_scale, projection_bias: |
463 | // constant inputs, describing projection matrix and bias. |
464 | // - proj_clip: if > 0, clip the output of the projection. |
// - output_state: output vector, size n_batch*n_output. Must be contiguous.
466 | // - asymmetric_quantize_inputs: parameter to control quantization. |
467 | // - projection_weights_row_sums, compute_row_sums, context: Data for optimized |
468 | // MatrixBatchVectorMultiplyAccumulate. |
469 | // - scratch0: scratch area of size n_batch*n_cell |
470 | // - scratch1: scratch area of size n_batch*n_cell |
471 | // - scratch2: scratch area of size n_batch |
472 | // - scratch3: scratch area of size n_batch |
473 | // - scratch4: scratch area used by MatrixBatchVectorMultiplyAccumulate |
474 | void CalculateLstmOutputHybrid( |
475 | int n_batch, int n_cell, int n_output, const float* cell_state, |
476 | const float* output_gate, TfLiteFusedActivation activation, |
477 | const int8_t* projection_weights, const uint8_t* projection_weights_ledger, |
478 | float projection_weights_scale, const float* projection_bias, |
479 | const float proj_clip, float* output_state, bool asymmetric_quantize_inputs, |
480 | int32_t* projection_weights_row_sums, bool* compute_row_sums, |
481 | CpuBackendContext* context, float* scratch0, int8_t* scratch1, |
482 | float* scratch2, int32_t* scratch3, int32_t* scratch4) { |
483 | tensor_utils::ApplyActivationToVector(cell_state, n_batch * n_cell, |
484 | activation, scratch0); |
485 | tensor_utils::VectorVectorCwiseProduct(output_gate, scratch0, |
486 | n_batch * n_cell, scratch0); |
487 | |
488 | const bool use_projection = (projection_weights != nullptr); |
489 | const bool use_projection_bias = (projection_bias != nullptr); |
490 | |
491 | if (use_projection) { |
492 | if (use_projection_bias) { |
493 | tensor_utils::VectorBatchVectorAssign(projection_bias, n_output, n_batch, |
494 | output_state); |
495 | } else { |
496 | std::fill_n(output_state, n_batch * n_output, 0.0f); |
497 | } |
498 | if (!tensor_utils::IsZeroVector(scratch0, n_batch * n_cell)) { |
499 | // Save quantization and matmul computation for all zero output. |
500 | tensor_utils::BatchQuantizeFloats(scratch0, n_batch, n_cell, scratch1, |
501 | scratch2, scratch3, |
502 | asymmetric_quantize_inputs); |
503 | if (projection_weights_ledger != nullptr) { |
504 | std::vector<float> scales(n_batch); |
505 | for (int i = 0; i < n_batch; i++) { |
506 | scales[i] = projection_weights_scale * scratch2[i]; |
507 | } |
508 | tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate( |
509 | projection_weights, projection_weights_ledger, n_output, n_cell, |
510 | scratch1, scales.data(), n_batch, output_state); |
511 | } else { |
512 | tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
513 | projection_weights, n_output, n_cell, scratch1, |
514 | projection_weights_scale, scratch2, n_batch, output_state, |
515 | /*per_channel_scale=*/nullptr, scratch3, scratch4, |
516 | projection_weights_row_sums, compute_row_sums, scratch2, context); |
517 | } |
518 | } |
519 | if (proj_clip > 0.0f) { |
520 | tensor_utils::CwiseClipping(output_state, n_batch * n_output, proj_clip); |
521 | } |
522 | } else { |
523 | std::copy_n(scratch0, n_batch * n_output, output_state); |
524 | } |
525 | } |
526 | |
527 | // Calculates a single LSTM gate, int8x8_16 version. |
528 | // Implements the same functionality as CalculateLstmGateFloat. |
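//
// In this variant activations, weights and the output state are int8 while the
// gate values and the cell state are int16. Rescaling between fixed-point
// formats is expressed as (scale_a, scale_b) pairs; as a rough sketch, the
// effective multiplier is scale_a * 2^scale_b, where scale_a is a quantized
// multiplier and scale_b a (possibly negative) shift.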
529 | void CalculateLstmGateInteger8x8_16( |
530 | // Input and weights |
531 | const int8_t* input, const int8_t* input_to_gate_weights, |
532 | const int32_t* input_to_gate_bias, const int32_t input_to_gate_scale_a, |
533 | const int32_t input_to_gate_scale_b, |
534 | // Output state and weights |
535 | const int8_t* output_state, const int8_t* recurrent_to_gate_weights, |
536 | const int32_t* recurrent_to_gate_bias, |
537 | const int32_t recurrent_to_gate_scale_a, |
538 | const int32_t recurrent_to_gate_scale_b, |
539 | // Cell state and weights |
540 | const int16_t* cell_state, const int16_t* cell_to_gate_weights, |
541 | const int32_t cell_to_gate_scale_a, const int32_t cell_to_gate_scale_b, |
542 | // Layer normalization parameters (layer norm LSTM) |
543 | const int16_t* layer_norm_coefficients, const int32_t* layer_norm_bias, |
544 | const int32_t layer_norm_input_scale_a, |
545 | const int32_t layer_norm_input_scale_b, |
546 | const int32_t layer_norm_variance_guard, |
547 | // Array sizes |
548 | const int n_batch, const int n_input, const int n_output, const int n_cell, |
549 | const TfLiteFusedActivation activation, |
550 | // Output |
551 | int16_t* gate, |
552 | // Parameters for performance optimizations |
553 | CpuBackendContext* context, |
554 | // Scratch arrays |
555 | int32_t* scratch5) { |
556 | const bool use_peephole = (cell_to_gate_weights != nullptr); |
557 | const bool use_layer_norm = (layer_norm_coefficients != nullptr); |
558 | |
559 | // Initialize scratch buffers with zeros. Note that unlike float and hybrid |
560 | // versions, bias is only used in layer normalization. |
561 | std::fill_n(gate, n_batch * n_cell, 0); |
562 | // For each batch and cell: compute input_weight * input. |
563 | tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
564 | input, input_to_gate_bias, input_to_gate_weights, input_to_gate_scale_a, |
565 | input_to_gate_scale_b, n_batch, n_input, n_cell, 0, scratch5, gate, |
566 | context); |
567 | // Note: no aux_input. |
568 | |
569 | // For each batch and cell: compute recurrent_weight * output_state. |
570 | tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
571 | output_state, recurrent_to_gate_bias, recurrent_to_gate_weights, |
572 | recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output, |
573 | n_cell, 0, scratch5, gate, context); |
574 | // For each batch and cell: compute cell_weight * cell_state (peephole LSTM) |
575 | if (use_peephole) { |
576 | tensor_utils::VectorBatchVectorCwiseProductAccumulate( |
577 | cell_to_gate_weights, n_output, cell_state, n_batch, |
578 | cell_to_gate_scale_a, cell_to_gate_scale_b, gate); |
579 | } |
580 | // Do layer normalization (if layer norm LSTM) |
581 | if (use_layer_norm) { |
582 | tensor_utils::ApplyLayerNorm( |
583 | gate, layer_norm_coefficients, layer_norm_bias, |
584 | layer_norm_input_scale_a, layer_norm_input_scale_b, |
585 | layer_norm_variance_guard, n_batch, n_cell, gate); |
586 | } |
587 | // Apply activation |
588 | switch (activation) { |
589 | case kTfLiteActSigmoid: |
590 | tensor_utils::ApplySigmoid(gate, n_batch, n_cell, gate); |
591 | break; |
592 | case kTfLiteActTanh: |
593 | tensor_utils::ApplyTanh(3, gate, n_batch, n_cell, gate); |
594 | break; |
595 | default: |
596 | // Only Sigmoid or Tanh is used. |
597 | TFLITE_ASSERT_FALSE; |
598 | } |
599 | } |
600 | |
601 | // Updates the LSTM cell state, used by both integer LSTM versions. |
602 | // Also see UpdateLstmCellFloat. |
603 | // |
604 | // Parameters: |
605 | // - n_batch, n_cell: sizes of vectors |
606 | // - cell_state: input/output vector, size n_batch*n_cell |
607 | // - cell_state_scale: scaling factor of cell state. |
608 | // - input_gate: input vector, size n_batch*n_cell. |
609 | // - forget_gate: input/scratch vector, size n_batch*n_cell, always modified. |
610 | // - cell_gate: input vector, size n_batch*n_cell. |
611 | // - use_cifg: use 1-forget_gate instead of input_gate. |
612 | // - clip: if > 0, clip the resulting cell state to [-clip, +clip]. |
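//
// Rough sketch of the fixed-point bookkeeping below, assuming gates are stored
// as Q0.15 int16 and the cell state uses a power-of-two scale
// 2^cell_state_scale (typically negative): multiplying a gate by the cell
// state and shifting right by 15 keeps the cell-state scale, while the Q0.30
// product input_gate * cell_gate is shifted by 30 + cell_state_scale to land
// on the cell-state scale.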
613 | void UpdateLstmCellInteger(int n_batch, int n_cell, int16_t* cell_state, |
614 | int32_t cell_state_scale, const int16_t* input_gate, |
615 | int16_t* forget_gate, const int16_t* cell_gate, |
616 | bool use_cifg, int16_t clip) { |
617 | // Use the forget_gate array as scratch, as input_gate array is not allocated |
618 | // in CIFG case. (Be careful not to write to the scratch before reading the |
619 | // forget gate data.) |
620 | int16_t* scratch = forget_gate; |
621 | |
622 | tensor_utils::CwiseMul(forget_gate, cell_state, n_batch, n_cell, 15, |
623 | cell_state); |
624 | if (use_cifg) { |
625 | tensor_utils::Sub1Vector(forget_gate, n_batch * n_cell, scratch); |
626 | tensor_utils::CwiseMul(scratch, cell_gate, n_batch, n_cell, |
627 | 30 + cell_state_scale, scratch); |
628 | } else { |
629 | tensor_utils::CwiseMul(input_gate, cell_gate, n_batch, n_cell, |
630 | 30 + cell_state_scale, scratch); |
631 | } |
632 | tensor_utils::CwiseAdd(cell_state, scratch, n_batch, n_cell, cell_state); |
633 | |
634 | if (clip > 0) { |
635 | tensor_utils::CwiseClipping(cell_state, n_batch * n_cell, clip); |
636 | } |
637 | } |
638 | |
639 | // Calculates the output state tensor of an LSTM step. See Float and hybrid |
640 | // versions as well. |
641 | // |
642 | // Parameters: |
643 | // - n_batch: batches: the number of distinct vectors in each array. |
644 | // - n_cell, n_output: sizes of vectors. |
645 | // - cell_state, output_gate: input vectors, size n_batch*n_cell. |
646 | // - cell_state_scale: scaling of cell_state. |
647 | // - hidden_scale_[a|b]: effective scale of cell_state.*output_gate |
648 | // - hidden_zp: zero_point for cell_state.*output_gate |
649 | // - projection_weights, proj_scale_[a|b], projection_bias: |
650 | // constant inputs, describing projection matrix and bias. |
651 | // - output_state_zp: zero point of output_state. (Input, calibrated value.) |
652 | // - quantized_proj_clip: if > 0, clip the output of the projection. |
// - output_state: output vector, size n_batch*n_output. Must be contiguous.
654 | // - context: data for optimized MatrixBatchVectorMultiplyAccumulate. |
655 | // - scratch0: scratch area of size n_batch*n_cell |
656 | // - scratch1: scratch area of size n_batch*n_cell |
657 | // - scratch2: scratch area used by MatrixBatchVectorMultiplyAccumulate |
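//
// Rough sketch of the first step below, assuming the cell state uses a
// power-of-two scale 2^cell_state_scale: ApplyTanh(15 + cell_state_scale, ...)
// treats the int16 cell state as having 15 + cell_state_scale integer bits
// (e.g. cell_state_scale = -11 would mean Q4.11), so the tanh output is Q0.15.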
658 | void CalculateLstmOutputInteger8x8_16( |
659 | int n_batch, int n_cell, int n_output, const int16_t* cell_state, |
660 | int32_t cell_state_scale, const int16_t* output_gate, |
661 | int32_t hidden_scale_a, int32_t hidden_scale_b, int32_t hidden_zp, |
662 | const int8_t* projection_weights, int32_t proj_scale_a, |
663 | int32_t proj_scale_b, const int32_t* projection_bias, |
664 | int32_t output_state_zp, int8_t quantized_proj_clip, int8_t* output_state, |
665 | CpuBackendContext* context, int16_t* scratch0, int8_t* scratch1, |
666 | int32_t* scratch2) { |
667 | // Note: unlike float/hybrid, the activation is always Tanh. |
668 | tensor_utils::ApplyTanh(15 + cell_state_scale, cell_state, n_batch, n_cell, |
669 | scratch0); |
670 | tensor_utils::CwiseMul(output_gate, scratch0, hidden_scale_a, hidden_scale_b, |
671 | n_batch, n_cell, hidden_zp, scratch1); |
672 | |
673 | const bool use_projection = (projection_weights != nullptr); |
674 | |
675 | if (use_projection) { |
676 | // Note: no bias like in float/hybrid |
677 | std::fill_n(output_state, n_batch * n_output, 0); |
678 | tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
679 | scratch1, projection_bias, projection_weights, proj_scale_a, |
680 | proj_scale_b, n_batch, n_cell, n_output, output_state_zp, scratch2, |
681 | output_state, context); |
682 | if (quantized_proj_clip > 0) { |
683 | tensor_utils::CwiseClipping(output_state, n_batch * n_output, |
684 | quantized_proj_clip); |
685 | } |
686 | } else { |
687 | std::copy_n(scratch1, n_batch * n_output, output_state); |
688 | } |
689 | } |
690 | |
691 | // Calculates a single LSTM gate, int8x8_8 version. |
692 | // Implements the same functionality as CalculateLstmGateFloat. |
693 | void CalculateLstmGateInteger8x8_8( |
694 | // Inputs and weights |
695 | const int8_t* input, int32_t input_zp, const int8_t* input_to_gate_weight, |
696 | const int32_t input_to_gate_scale_a, const int32_t input_to_gate_scale_b, |
697 | const int32_t input_times_weights_scale_a, |
698 | const int32_t input_times_weights_scale_b, |
699 | const int32_t input_times_weights_zp, |
700 | // Output state and weights |
701 | const int8_t* output_state, const int32_t output_state_zp, |
702 | const int8_t* recurrent_to_gate_weight, |
703 | const int32_t recurrent_to_gate_scale_a, |
704 | const int32_t recurrent_to_gate_scale_b, |
705 | const int32_t output_state_times_weights_scale_a, |
706 | const int32_t output_state_times_weights_scale_b, |
707 | const int32_t output_state_times_weights_zp, |
708 | // Layer normalization parameters (layer norm LSTM) |
709 | const int16_t* layer_norm_gate_weight, |
710 | const int32_t layer_norm_gate_scale_a, |
711 | const int32_t layer_norm_gate_scale_b, const int32_t* gate_bias, |
712 | // Array sizes |
713 | const int n_batch, const int n_input, const int n_output, const int n_cell, |
714 | const TfLiteFusedActivation activation, |
715 | // Output |
716 | int16_t* gate, |
717 | // Scratch arrays, both sized n_batch*n_cell |
718 | int8_t* scratch0, int8_t* scratch1) { |
719 | // Multiply input * input_weights => scratch0 |
720 | tensor_utils::MatrixBatchVectorMultiply( |
721 | input, input_zp, input_to_gate_weight, input_to_gate_scale_a, |
722 | input_to_gate_scale_b, n_batch, n_input, n_cell, scratch0, |
723 | input_times_weights_zp); |
724 | // Multiply output_state * recurrent_weights => scratch1 |
725 | tensor_utils::MatrixBatchVectorMultiply( |
726 | output_state, output_state_zp, recurrent_to_gate_weight, |
727 | recurrent_to_gate_scale_a, recurrent_to_gate_scale_b, n_batch, n_output, |
728 | n_cell, scratch1, output_state_times_weights_zp); |
729 | // Add scratch0 + scratch1 => gate |
730 | tensor_utils::TwoGateSaturatingAdd( |
731 | scratch0, input_times_weights_zp, scratch1, output_state_times_weights_zp, |
732 | input_times_weights_scale_a, input_times_weights_scale_b, |
733 | output_state_times_weights_scale_a, output_state_times_weights_scale_b, |
734 | n_batch, n_cell, gate); |
735 | // Apply layer normalization. |
736 | tensor_utils::ApplyLayerNormFloat( |
737 | gate, layer_norm_gate_weight, layer_norm_gate_scale_a, |
738 | layer_norm_gate_scale_b, gate_bias, n_batch, n_cell, gate); |
739 | // Apply activation. |
740 | switch (activation) { |
741 | case kTfLiteActSigmoid: |
742 | tensor_utils::ApplySigmoidFloat(gate, n_batch, n_cell, gate); |
743 | break; |
744 | case kTfLiteActTanh: |
745 | tensor_utils::ApplyTanhFloat(gate, n_batch, n_cell, -12, gate); |
746 | break; |
747 | default: |
748 | // Only Sigmoid or Tanh is used. |
749 | TFLITE_ASSERT_FALSE; |
750 | } |
751 | } |
752 | |
753 | // Calculates the output state tensor of an LSTM step. See Float and hybrid |
754 | // versions as well. |
755 | // |
756 | // Parameters: |
757 | // - n_batch: batches: the number of distinct vectors in each array. |
758 | // - n_cell, n_output: sizes of vectors. |
759 | // - cell_state, output_gate: input vectors, size n_batch*n_cell. |
760 | // - projection_weights, proj_scale_[a|b], projection_bias: |
761 | // constant inputs, describing projection matrix and bias. |
762 | // - output_state_zp: zero point of the output state. |
763 | // - quantized_proj_clip: if > 0, clip the output of the projection. |
// - output_state: output vector, size n_batch*n_output. Must be contiguous.
765 | // - scratch: scratch area of size n_batch*n_cell |
766 | void CalculateLstmOutputInteger8x8_8( |
767 | int n_batch, int n_cell, int n_output, const int16_t* cell_state, |
768 | const int16_t* output_gate, const int8_t* projection_weights, |
769 | int32_t proj_scale_a, int32_t proj_scale_b, const int32_t* projection_bias, |
770 | int32_t output_state_zp, int32_t quantized_proj_clip, int8_t* output_state, |
771 | int16_t* scratch) { |
772 | // Note: unlike float/hybrid, the activation is always Tanh. |
773 | tensor_utils::ApplyTanhFloat(cell_state, n_batch, n_cell, -15, scratch); |
774 | tensor_utils::CwiseMul(output_gate, scratch, n_batch, n_cell, 15 + 15 - 15, |
775 | scratch); |
776 | // Note: no bias like in float/hybrid |
777 | tensor_utils::MatrixBatchVectorMultiply( |
778 | scratch, projection_weights, proj_scale_a, proj_scale_b, projection_bias, |
779 | n_batch, n_cell, n_output, output_state_zp, output_state); |
780 | if (quantized_proj_clip > 0) { |
781 | tensor_utils::CwiseClipping(output_state, n_batch * n_output, |
782 | quantized_proj_clip); |
783 | } |
784 | } |
785 | |
786 | // Performs an LSTM batch inference step for input specified by input_ptr. |
787 | // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and |
788 | // biases (*_bias_ptr), and buffers (*_scratch), along with additional |
789 | // parameters: |
790 | // - params: various LSTM params including activation, clipping, etc., |
791 | // - n_batch: size of batch, |
792 | // - n_cell: number of cells (or units), |
793 | // - n_input: the input size, |
794 | // - n_aux_input: the auxiliary input size. |
795 | // - n_output: the output size. |
796 | // - output_batch_leading_dim: the leading dimension of the output buffer. |
797 | // - context: the CpuBackendContext for use with matrix multiplications. |
798 | // |
799 | // Input of size 'n_batch * n_input': |
800 | // input_ptr |
801 | // Input of size 'n_batch * n_aux_input': |
802 | // aux_input_ptr - optional (can be nullptr) |
803 | // |
804 | // LSTM weights: |
805 | // Input weights of size 'n_cell * n_input': |
806 | // input_to_input_weights - optional |
807 | // input_to_forget_weights |
808 | // input_to_cell_weights |
809 | // input_to_output_weights |
810 | // Auxiliary input weights of size 'n_cell * n_aux_input': |
811 | // aux_input_to_input_weights - optional |
812 | // aux_input_to_forget_weights - optional |
813 | // aux_input_to_cell_weights - optional |
814 | // aux_input_to_output_weights - optional |
815 | // Recurrent weights of size 'n_cell * n_output': |
816 | // recurrent_to_input_weights - optional |
817 | // recurrent_to_forget_weights |
818 | // recurrent_to_cell_weights |
//   recurrent_to_output_weights
820 | // Peephole weights of size 'n_cell', representing diagonal matrices. |
821 | // cell_to_input_weights - optional |
//   cell_to_forget_weights - optional
823 | // cell_to_output_weights - optional |
824 | // Projection weights of size 'n_output * n_cell' |
825 | // projection_weights_ptr - optional |
826 | // Gate biases of size 'n_cell': |
827 | // input_gate_bias_ptr - optional |
828 | // forget_gate_bias_ptr |
829 | // cell_gate_bias_ptr |
830 | // output_gate_bias_ptr |
831 | // |
832 | // Layer norm coefficients of size 'n_cell', representing diagonal matrices. |
833 | // input_layer_norm_coefficients_ptr - optional |
834 | // forget_layer_norm_coefficients_ptr - optional |
835 | // cell_layer_norm_coefficients_ptr - optional |
836 | // output_layer_norm_coefficients_ptr - optional |
837 | // |
838 | // The pointers to the cell and output state and the output are updated. |
839 | // |
840 | // The pointers input_ptr, aux_input_ptr, and output_ptr point to data aligned |
841 | // in batch_major order, and each step processes batch_size many inputs from |
842 | // input_ptr, and updates batch_size many cell and output states. |
843 | // |
// The output_batch_leading_dim is output.shape[-1], i.e. the innermost (last)
// dimension of the output tensor, and in most cases it equals n_output. It
// differs when we want to store the LSTM output into a slice of the output
// tensor, e.g. for bidirectional LSTMs with merge_outputs. In this case, the
// batched operations cannot be used since they assume that the batched outputs
// are contiguous, and we manually loop over the batched outputs.
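//
// Scratch buffers: scratch0..scratch3 hold the input, forget, cell and output
// gate activations respectively (each of size n_batch * n_cell), and scratch4
// is an extra accumulation buffer used by the gate and projection
// computations.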
850 | // LINT.IfChange |
851 | inline void LstmStepFloat( |
852 | const float* input_ptr, const float* input_to_input_weights_ptr, |
853 | const float* input_to_forget_weights_ptr, |
854 | const float* input_to_cell_weights_ptr, |
855 | const float* input_to_output_weights_ptr, const float* aux_input_ptr, |
856 | const float* aux_input_to_input_weights_ptr, |
857 | const float* aux_input_to_forget_weights_ptr, |
858 | const float* aux_input_to_cell_weights_ptr, |
859 | const float* aux_input_to_output_weights_ptr, |
860 | const float* recurrent_to_input_weights_ptr, |
861 | const float* recurrent_to_forget_weights_ptr, |
862 | const float* recurrent_to_cell_weights_ptr, |
863 | const float* recurrent_to_output_weights_ptr, |
864 | const float* cell_to_input_weights_ptr, |
865 | const float* cell_to_forget_weights_ptr, |
866 | const float* cell_to_output_weights_ptr, |
867 | const float* input_layer_norm_coefficients_ptr, |
868 | const float* forget_layer_norm_coefficients_ptr, |
869 | const float* cell_layer_norm_coefficients_ptr, |
870 | const float* output_layer_norm_coefficients_ptr, |
871 | const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr, |
872 | const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr, |
873 | const float* projection_weights_ptr, const float* projection_bias_ptr, |
874 | const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, |
875 | int n_aux_input, int n_output, int output_batch_leading_dim, |
876 | float* output_state_ptr, float* cell_state_ptr, float* scratch0, |
877 | float* scratch1, float* scratch2, float* scratch3, float* scratch4, |
878 | float* output_ptr, CpuBackendContext* context) { |
ruy::profiler::ScopeLabel label("LstmStepFloat");
// Since we have already checked that the weights are either all present or
// all absent, checking just one of them is enough to get the condition.
882 | const bool use_cifg = (input_to_input_weights_ptr == nullptr); |
883 | |
884 | // Make named scratch buffers. |
885 | float* input_gate_scratch = scratch0; |
886 | float* forget_gate_scratch = scratch1; |
887 | float* cell_gate_scratch = scratch2; |
888 | float* output_gate_scratch = scratch3; |
889 | float* accumulation_scratch_buffer = scratch4; |
890 | |
891 | // Check if inputs are all zeros so we can skip some computations. |
892 | const bool is_input_all_zeros = |
893 | tensor_utils::IsZeroVector(input_ptr, n_batch * n_input); |
894 | const bool is_aux_input_all_zeros = |
895 | (aux_input_ptr == nullptr || |
896 | tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)); |
897 | |
898 | if (!use_cifg) { |
899 | // Calculate the input gate. (If not CIFG.) |
900 | CalculateLstmGateFloat(input_ptr, input_to_input_weights_ptr, aux_input_ptr, |
901 | aux_input_to_input_weights_ptr, output_state_ptr, |
recurrent_to_input_weights_ptr, cell_state_ptr, cell_to_input_weights_ptr,
905 | input_layer_norm_coefficients_ptr, |
906 | input_gate_bias_ptr, n_batch, n_input, n_aux_input, |
907 | n_output, n_cell, |
908 | /*activation=*/kTfLiteActSigmoid, input_gate_scratch, |
909 | is_input_all_zeros, is_aux_input_all_zeros, |
910 | accumulation_scratch_buffer, context); |
911 | } |
912 | // Calculate the forget gate. |
913 | CalculateLstmGateFloat( |
914 | input_ptr, input_to_forget_weights_ptr, aux_input_ptr, |
915 | aux_input_to_forget_weights_ptr, output_state_ptr, |
recurrent_to_forget_weights_ptr, cell_state_ptr, cell_to_forget_weights_ptr,
919 | forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch, |
920 | n_input, n_aux_input, n_output, n_cell, |
921 | /*activation=*/kTfLiteActSigmoid, forget_gate_scratch, is_input_all_zeros, |
922 | is_aux_input_all_zeros, accumulation_scratch_buffer, context); |
923 | // Calculate the cell update gate. |
924 | CalculateLstmGateFloat( |
925 | input_ptr, input_to_cell_weights_ptr, aux_input_ptr, |
926 | aux_input_to_cell_weights_ptr, output_state_ptr, |
recurrent_to_cell_weights_ptr, /*cell_state=*/nullptr,
930 | /*cell_to_gate_weights=*/nullptr, cell_layer_norm_coefficients_ptr, |
931 | cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell, |
932 | params->activation, cell_gate_scratch, is_input_all_zeros, |
933 | is_aux_input_all_zeros, accumulation_scratch_buffer, context); |
934 | // Update the cell state. |
935 | UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, |
936 | forget_gate_scratch, cell_gate_scratch, use_cifg, |
937 | params->cell_clip); |
938 | // Calculate output gate. |
939 | CalculateLstmGateFloat( |
940 | input_ptr, input_to_output_weights_ptr, aux_input_ptr, |
941 | aux_input_to_output_weights_ptr, output_state_ptr, |
recurrent_to_output_weights_ptr, cell_state_ptr, cell_to_output_weights_ptr,
945 | output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch, |
946 | n_input, n_aux_input, n_output, n_cell, |
947 | /*activation=*/kTfLiteActSigmoid, output_gate_scratch, is_input_all_zeros, |
948 | is_aux_input_all_zeros, accumulation_scratch_buffer, context); |
949 | // Update the output state. |
950 | CalculateLstmOutputFloat(n_batch, n_cell, n_output, cell_state_ptr, |
951 | output_gate_scratch, params->activation, |
952 | projection_weights_ptr, projection_bias_ptr, |
953 | params->proj_clip, output_state_ptr, scratch2, |
954 | accumulation_scratch_buffer, context); |
955 | // Copy output state to the output. Note that the output's rows may not be |
956 | // contiguous (output_batch_leading_dim != n_output). |
957 | for (int b = 0; b < n_batch; b++) { |
958 | std::copy_n(output_state_ptr + b * n_output, n_output, |
959 | output_ptr + b * output_batch_leading_dim); |
960 | } |
961 | } |
962 | // LINT.ThenChange(../tools/optimize/calibration/builtin_logging_ops/lstm.cc,\ |
963 | // ../experimental/kernels/fp16/lstm_eval.cc) |
964 | |
965 | // Same as above but with quantized weight matrices. In detail: |
966 | // Input of size 'n_batch * n_input': |
967 | // input_ptr |
968 | // Input of size 'n_batch * n_aux_input': |
969 | // aux_input_ptr - optional (can be nullptr) |
970 | // |
971 | // LSTM weights: |
972 | // Quantized input weights of size 'n_cell * n_input': |
973 | // input_to_input_weights - optional |
974 | // input_to_forget_weights |
975 | // input_to_cell_weights |
//   input_to_output_weights
977 | // Quantized auxiliary input weights of size 'n_cell * n_aux_input': |
978 | // aux_input_to_input_weights - optional |
979 | // aux_input_to_forget_weights - optional |
980 | // aux_input_to_cell_weights - optional |
981 | // aux_input_to_output_weights - optional |
982 | // Quantized recurrent weights of size 'n_cell * n_output': |
983 | // recurrent_to_input_weights - optional |
984 | // recurrent_to_forget_weights |
985 | // recurrent_to_cell_weights |
//   recurrent_to_output_weights
987 | // Quantized peephole weights of size 'n_cell', representing diagonal matrices. |
988 | // cell_to_input_weights - optional |
//   cell_to_forget_weights - optional
990 | // cell_to_output_weights - optional |
991 | // Quantized projection weights of size 'n_output * n_cell' |
992 | // projection_weights_ptr - optional |
993 | // Weight scales (scalars) for each of the weights above. |
994 | // input_to_input_weights_scale - optional |
995 | // input_to_forget_weights_scale |
996 | // input_to_cell_weights_scale |
997 | // input_to_output_weights_scale |
998 | // aux_input_to_input_weights_scale - optional |
999 | // aux_input_to_forget_weights_scale - optional |
1000 | // aux_input_to_cell_weights_scale - optional |
1001 | // aux_input_to_output_weights_scale - optional |
1002 | // recurrent_to_input_weights_scale - optional |
1003 | // recurrent_to_forget_weights_scale |
1004 | // recurrent_to_cell_weights_scale |
1005 | // recurrent_to_output_weights_scale |
//   cell_to_input_weights_scale - optional
//   cell_to_forget_weights_scale - optional
//   cell_to_output_weights_scale - optional
1009 | // projection_weights_scale - optional |
1010 | // Gate biases of size 'n_cell': |
1011 | // input_gate_bias_ptr - optional |
1012 | // forget_gate_bias_ptr |
1013 | // cell_gate_bias_ptr |
1014 | // output_gate_bias_ptr |
1015 | // |
1016 | // Layer norm coefficients of size 'n_cell', representing diagonal matrices. |
1017 | // input_layer_norm_coefficients_ptr - optional |
1018 | // forget_layer_norm_coefficients_ptr - optional |
1019 | // cell_layer_norm_coefficients_ptr - optional |
1020 | // output_layer_norm_coefficients_ptr - optional |
1021 | // |
1022 | // Temporary pre-allocated storage for quantized values: |
1023 | // quantized_input_ptr (same size as input_ptr) |
1024 | // quantized_output_state_ptr (same size as output_state_ptr) |
1025 | // quantized_output_scratch (same size as cell_state_ptr) |
1026 | // Temporary pre-allocated storage for recovered values: |
1027 | // recovered_cell_weights (same size as cell_to_*_weights) |
1028 | // |
1029 | // Outputs: |
1030 | // output_state_ptr - size 'n_batch * n_output' |
1031 | // cell_state_ptr - size 'n_batch * n_cell' |
1032 | // output_ptr - size 'n_batch * output_batch_leading_dim' |
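//
// Scratch buffers scratch0..scratch3 hold the input, forget, cell and output
// gate activations respectively (each of size n_batch * n_cell).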
1033 | inline void LstmStepHybrid( |
1034 | const float* input_ptr, const int8_t* input_to_input_weights_ptr, |
1035 | const uint8_t* input_to_input_weights_ledger_ptr, |
1036 | float input_to_input_weights_scale, |
1037 | const int8_t* input_to_forget_weights_ptr, |
1038 | const uint8_t* input_to_forget_weights_ledger_ptr, |
1039 | float input_to_forget_weights_scale, |
1040 | const int8_t* input_to_cell_weights_ptr, |
1041 | const uint8_t* input_to_cell_weights_ledger_ptr, |
1042 | float input_to_cell_weights_scale, |
1043 | const int8_t* input_to_output_weights_ptr, |
1044 | const uint8_t* input_to_output_weights_ledger_ptr, |
1045 | float input_to_output_weights_scale, const float* aux_input_ptr, |
1046 | const int8_t* aux_input_to_input_weights_ptr, |
1047 | float aux_input_to_input_weights_scale, |
1048 | const int8_t* aux_input_to_forget_weights_ptr, |
1049 | float aux_input_to_forget_weights_scale, |
1050 | const int8_t* aux_input_to_cell_weights_ptr, |
1051 | float aux_input_to_cell_weights_scale, |
1052 | const int8_t* aux_input_to_output_weights_ptr, |
1053 | float aux_input_to_output_weights_scale, |
1054 | const int8_t* recurrent_to_input_weights_ptr, |
1055 | const uint8_t* recurrent_to_input_weights_ledger_ptr, |
1056 | float recurrent_to_input_weights_scale, |
1057 | const int8_t* recurrent_to_forget_weights_ptr, |
1058 | const uint8_t* recurrent_to_forget_weights_ledger_ptr, |
1059 | float recurrent_to_forget_weights_scale, |
1060 | const int8_t* recurrent_to_cell_weights_ptr, |
1061 | const uint8_t* recurrent_to_cell_weights_ledger_ptr, |
1062 | float recurrent_to_cell_weights_scale, |
1063 | const int8_t* recurrent_to_output_weights_ptr, |
1064 | const uint8_t* recurrent_to_output_weights_ledger_ptr, |
1065 | float recurrent_to_output_weights_scale, |
1066 | const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, |
1067 | const int8_t* cell_to_forget_weights_ptr, |
1068 | float cell_to_forget_weights_scale, |
1069 | const int8_t* cell_to_output_weights_ptr, |
1070 | float cell_to_output_weights_scale, |
1071 | const float* input_layer_norm_coefficients_ptr, |
1072 | const float* forget_layer_norm_coefficients_ptr, |
1073 | const float* cell_layer_norm_coefficients_ptr, |
1074 | const float* output_layer_norm_coefficients_ptr, |
1075 | const float* input_gate_bias_ptr, const float* forget_gate_bias_ptr, |
1076 | const float* cell_gate_bias_ptr, const float* output_gate_bias_ptr, |
1077 | const int8_t* projection_weights_ptr, |
1078 | const uint8_t* projection_weights_ledger_ptr, |
1079 | float projection_weights_scale, const float* projection_bias_ptr, |
1080 | const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, |
1081 | int n_aux_input, int n_output, int output_batch_leading_dim, |
1082 | float* scratch0, float* scratch1, float* scratch2, float* scratch3, |
1083 | float* input_sf, float* aux_input_sf, float* output_state_sf, |
1084 | float* scaling_factors_scratch, float* recovered_cell_weights, |
1085 | int8_t* quantized_input_ptr, int8_t* quantized_aux_input_ptr, |
1086 | int8_t* quantized_output_state_ptr, int8_t* quantized_output_scratch, |
1087 | float* output_state_ptr, float* cell_state_ptr, int32_t* accum_scratch_ptr, |
1088 | float* output_ptr, int32_t* input_zp, int32_t* aux_input_zp, |
1089 | int32_t* output_state_zp, int32_t* row_sums, int row_sums_size, |
1090 | bool* compute_row_sums, bool asymmetric_quantize_inputs, |
1091 | CpuBackendContext* context) { |
ruy::profiler::ScopeLabel label("LstmStepHybrid");
// Since we have already checked that the weights are either all present or
// all absent, checking just one of them is enough to get the condition.
1095 | const bool use_cifg = (input_to_input_weights_ptr == nullptr); |
1096 | // Make named scratch buffers for the different gates. |
1097 | float* input_gate_scratch = scratch0; |
1098 | float* forget_gate_scratch = scratch1; |
1099 | float* cell_gate_scratch = scratch2; |
1100 | float* output_gate_scratch = scratch3; |
1101 | |
1102 | int32_t* input_to_input_row_sums = nullptr; |
1103 | int32_t* input_to_forget_row_sums = nullptr; |
1104 | int32_t* input_to_cell_row_sums = nullptr; |
1105 | int32_t* input_to_output_row_sums = nullptr; |
1106 | int32_t* aux_input_to_input_row_sums = nullptr; |
1107 | int32_t* aux_input_to_forget_row_sums = nullptr; |
1108 | int32_t* aux_input_to_cell_row_sums = nullptr; |
1109 | int32_t* aux_input_to_output_row_sums = nullptr; |
1110 | int32_t* recurrent_to_input_row_sums = nullptr; |
1111 | int32_t* recurrent_to_forget_row_sums = nullptr; |
1112 | int32_t* recurrent_to_cell_row_sums = nullptr; |
1113 | int32_t* recurrent_to_output_row_sums = nullptr; |
1114 | int32_t* projection_weights_row_sums = nullptr; |
1115 | |
1116 | if (asymmetric_quantize_inputs) { |
1117 | int num_row_sums = use_cifg ? 6 : 8; |
1118 | if (aux_input_ptr != nullptr) { |
1119 | num_row_sums += use_cifg ? 3 : 4; |
1120 | } |
1121 | if (projection_weights_ptr != nullptr) { |
1122 | num_row_sums += ceil(static_cast<float>(n_output) / n_cell); |
1123 | } |
1124 | TF_LITE_ASSERT(row_sums_size == num_row_sums); |
1125 | input_to_input_row_sums = row_sums; |
1126 | input_to_forget_row_sums = |
1127 | use_cifg ? input_to_input_row_sums : input_to_input_row_sums + n_cell; |
1128 | input_to_cell_row_sums = input_to_forget_row_sums + n_cell; |
1129 | input_to_output_row_sums = input_to_cell_row_sums + n_cell; |
1130 | if (aux_input_ptr != nullptr) { |
1131 | aux_input_to_input_row_sums = input_to_output_row_sums + n_cell; |
1132 | aux_input_to_forget_row_sums = use_cifg |
1133 | ? aux_input_to_input_row_sums |
1134 | : aux_input_to_input_row_sums + n_cell; |
1135 | aux_input_to_cell_row_sums = aux_input_to_forget_row_sums + n_cell; |
1136 | aux_input_to_output_row_sums = aux_input_to_cell_row_sums + n_cell; |
1137 | } |
1138 | recurrent_to_input_row_sums = aux_input_ptr |
1139 | ? aux_input_to_output_row_sums + n_cell |
1140 | : input_to_output_row_sums + n_cell; |
1141 | recurrent_to_forget_row_sums = use_cifg |
1142 | ? recurrent_to_input_row_sums |
1143 | : recurrent_to_input_row_sums + n_cell; |
1144 | recurrent_to_cell_row_sums = recurrent_to_forget_row_sums + n_cell; |
1145 | recurrent_to_output_row_sums = recurrent_to_cell_row_sums + n_cell; |
1146 | if (projection_weights_ptr != nullptr) { |
1147 | projection_weights_row_sums = recurrent_to_output_row_sums + n_cell; |
1148 | } |
1149 | if (*compute_row_sums) { |
1150 | ComputeRowSums( |
1151 | input_to_input_row_sums, input_to_forget_row_sums, |
1152 | input_to_cell_row_sums, input_to_output_row_sums, |
1153 | aux_input_to_input_row_sums, aux_input_to_forget_row_sums, |
1154 | aux_input_to_cell_row_sums, aux_input_to_output_row_sums, |
1155 | recurrent_to_input_row_sums, recurrent_to_forget_row_sums, |
1156 | recurrent_to_cell_row_sums, recurrent_to_output_row_sums, |
1157 | projection_weights_row_sums, row_sums, n_cell, n_input, n_aux_input, |
1158 | n_output, input_to_input_weights_ptr, input_to_forget_weights_ptr, |
1159 | input_to_cell_weights_ptr, input_to_output_weights_ptr, |
1160 | aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr, |
1161 | aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr, |
1162 | recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, |
1163 | recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, |
1164 | projection_weights_ptr, use_cifg, aux_input_ptr); |
1165 | *compute_row_sums = false; |
1166 | } |
1167 | } |
1168 | |
1169 | // Check if inputs are all zeros so we can skip some computations. |
1170 | const bool is_input_all_zeros = |
1171 | tensor_utils::IsZeroVector(input_ptr, n_batch * n_input); |
1172 | const bool is_aux_input_all_zeros = |
1173 | (aux_input_ptr == nullptr || |
1174 | tensor_utils::IsZeroVector(aux_input_ptr, n_batch * n_aux_input)); |
1175 | const bool is_output_state_all_zeros = |
1176 | tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output); |
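  // A matmul against an all-zero vector adds nothing to the gate accumulators,
  // so both the quantization below and the corresponding matmuls inside the
  // gate calculations can be skipped for all-zero inputs.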
1177 | // Quantize inputs. |
1178 | if (!is_input_all_zeros) { |
1179 | tensor_utils::BatchQuantizeFloats(input_ptr, n_batch, n_input, |
1180 | quantized_input_ptr, input_sf, input_zp, |
1181 | asymmetric_quantize_inputs); |
1182 | } |
1183 | if (!is_aux_input_all_zeros) { |
1184 | tensor_utils::BatchQuantizeFloats(aux_input_ptr, n_batch, n_aux_input, |
1185 | quantized_aux_input_ptr, aux_input_sf, |
1186 | aux_input_zp, asymmetric_quantize_inputs); |
1187 | } |
1188 | if (!is_output_state_all_zeros) { |
1189 | tensor_utils::BatchQuantizeFloats( |
1190 | output_state_ptr, n_batch, n_output, quantized_output_state_ptr, |
1191 | output_state_sf, output_state_zp, asymmetric_quantize_inputs); |
1192 | } |
1193 | if (!use_cifg) { |
1194 | // Calculate the input gate. (If not CIFG.) |
1195 | CalculateLstmGateHybrid( |
1196 | quantized_input_ptr, input_sf, input_zp, input_to_input_weights_ptr, |
1197 | input_to_input_weights_ledger_ptr, input_to_input_weights_scale, |
1198 | input_to_input_row_sums, quantized_aux_input_ptr, aux_input_sf, |
1199 | aux_input_zp, aux_input_to_input_weights_ptr, |
1200 | aux_input_to_input_weights_scale, aux_input_to_input_row_sums, |
1201 | quantized_output_state_ptr, output_state_sf, output_state_zp, |
1202 | recurrent_to_input_weights_ptr, recurrent_to_input_weights_ledger_ptr, |
1203 | recurrent_to_input_weights_scale, recurrent_to_input_row_sums, |
1204 | cell_state_ptr, cell_to_input_weights_ptr, cell_to_input_weights_scale, |
1205 | input_layer_norm_coefficients_ptr, input_gate_bias_ptr, n_batch, |
1206 | n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid, |
1207 | input_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros, |
1208 | is_output_state_all_zeros, compute_row_sums, context, |
1209 | scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr); |
1210 | } |
1211 | // Calculate the forget gate. |
1212 | CalculateLstmGateHybrid( |
1213 | quantized_input_ptr, input_sf, input_zp, input_to_forget_weights_ptr, |
1214 | input_to_forget_weights_ledger_ptr, input_to_forget_weights_scale, |
1215 | input_to_forget_row_sums, quantized_aux_input_ptr, aux_input_sf, |
1216 | aux_input_zp, aux_input_to_forget_weights_ptr, |
1217 | aux_input_to_forget_weights_scale, aux_input_to_forget_row_sums, |
1218 | quantized_output_state_ptr, output_state_sf, output_state_zp, |
1219 | recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_ledger_ptr, |
1220 | recurrent_to_forget_weights_scale, recurrent_to_forget_row_sums, |
1221 | cell_state_ptr, cell_to_forget_weights_ptr, cell_to_forget_weights_scale, |
1222 | forget_layer_norm_coefficients_ptr, forget_gate_bias_ptr, n_batch, |
1223 | n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid, |
1224 | forget_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros, |
1225 | is_output_state_all_zeros, compute_row_sums, context, |
1226 | scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr); |
1227 | // Calculate the cell update gate. |
1228 | CalculateLstmGateHybrid( |
1229 | quantized_input_ptr, input_sf, input_zp, input_to_cell_weights_ptr, |
1230 | input_to_cell_weights_ledger_ptr, input_to_cell_weights_scale, |
1231 | input_to_cell_row_sums, quantized_aux_input_ptr, aux_input_sf, |
1232 | aux_input_zp, aux_input_to_cell_weights_ptr, |
1233 | aux_input_to_cell_weights_scale, aux_input_to_cell_row_sums, |
1234 | quantized_output_state_ptr, output_state_sf, output_state_zp, |
1235 | recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_ledger_ptr, |
1236 | recurrent_to_cell_weights_scale, recurrent_to_cell_row_sums, |
1237 | /*cell_state=*/nullptr, /*cell_to_gate_weights=*/nullptr, |
1238 | /*cell_to_gate_weights_scale=*/0.0f, cell_layer_norm_coefficients_ptr, |
1239 | cell_gate_bias_ptr, n_batch, n_input, n_aux_input, n_output, n_cell, |
1240 | params->activation, cell_gate_scratch, is_input_all_zeros, |
1241 | is_aux_input_all_zeros, is_output_state_all_zeros, compute_row_sums, |
1242 | context, scaling_factors_scratch, recovered_cell_weights, |
1243 | accum_scratch_ptr); |
1244 | // Update the cell state. |
1245 | UpdateLstmCellFloat(n_batch, n_cell, cell_state_ptr, input_gate_scratch, |
1246 | forget_gate_scratch, cell_gate_scratch, use_cifg, |
1247 | params->cell_clip); |
1248 | // Calculate the output gate. |
1249 | CalculateLstmGateHybrid( |
1250 | quantized_input_ptr, input_sf, input_zp, input_to_output_weights_ptr, |
1251 | input_to_output_weights_ledger_ptr, input_to_output_weights_scale, |
1252 | input_to_output_row_sums, quantized_aux_input_ptr, aux_input_sf, |
1253 | aux_input_zp, aux_input_to_output_weights_ptr, |
1254 | aux_input_to_output_weights_scale, aux_input_to_output_row_sums, |
1255 | quantized_output_state_ptr, output_state_sf, output_state_zp, |
1256 | recurrent_to_output_weights_ptr, recurrent_to_output_weights_ledger_ptr, |
1257 | recurrent_to_output_weights_scale, recurrent_to_output_row_sums, |
1258 | cell_state_ptr, cell_to_output_weights_ptr, cell_to_output_weights_scale, |
1259 | output_layer_norm_coefficients_ptr, output_gate_bias_ptr, n_batch, |
1260 | n_input, n_aux_input, n_output, n_cell, kTfLiteActSigmoid, |
1261 | output_gate_scratch, is_input_all_zeros, is_aux_input_all_zeros, |
1262 | is_output_state_all_zeros, compute_row_sums, context, |
1263 | scaling_factors_scratch, recovered_cell_weights, accum_scratch_ptr); |
1264 | // Update the output state. |
1265 | CalculateLstmOutputHybrid( |
1266 | n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch, |
1267 | params->activation, projection_weights_ptr, projection_weights_ledger_ptr, |
1268 | projection_weights_scale, projection_bias_ptr, params->proj_clip, |
1269 | output_state_ptr, asymmetric_quantize_inputs, projection_weights_row_sums, |
1270 | compute_row_sums, context, scratch2, quantized_output_scratch, input_sf, |
1271 | input_zp, accum_scratch_ptr); |
1272 | // Copy output state to the output. Note that the output's rows may not be |
1273 | // contiguous (output_batch_leading_dim != n_output). |
1274 | for (int b = 0; b < n_batch; b++) { |
1275 | std::copy_n(output_state_ptr + b * n_output, n_output, |
1276 | output_ptr + b * output_batch_leading_dim); |
1277 | } |
1278 | } |
1279 | |
1280 | // Fully quantized lstm kernel for 16 bit gate matmul output. |
1281 | // |
1282 | // Input tensor of size n_batch * n_input: |
1283 | // input_ptr |
1284 | // |
1285 | // LSTM weights: |
// Quantized input weights of size 'n_cell * n_input':
//   input_to_input_weight_ptr - optional
//   input_to_forget_weight_ptr
//   input_to_cell_weight_ptr
//   input_to_output_weight_ptr
1291 | // |
// Quantized recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weight_ptr - optional
//   recurrent_to_forget_weight_ptr
//   recurrent_to_cell_weight_ptr
//   recurrent_to_output_weight_ptr
1297 | // |
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
//   cell_to_input_weights - optional
//   cell_to_forget_weights - optional
//   cell_to_output_weights - optional
1302 | // |
1303 | // Quantized projection weights of size 'n_output * n_cell' |
1304 | // projection_weight_ptr - optional |
1305 | // |
1306 | // Weight scales (scalars) for each of the weights above. |
1307 | // effective_input_to_input_scale_a - optional |
1308 | // effective_input_to_input_scale_b - optional |
1309 | // effective_input_to_forget_scale_a |
1310 | // effective_input_to_forget_scale_b |
1311 | // effective_input_to_cell_scale_a |
1312 | // effective_input_to_cell_scale_b |
1313 | // effective_input_to_output_scale_a |
1314 | // effective_input_to_output_scale_b |
1315 | // effective_recurrent_to_input_scale_a - optional |
1316 | // effective_recurrent_to_input_scale_b - optional |
1317 | // effective_recurrent_to_forget_scale_a |
1318 | // effective_recurrent_to_forget_scale_b |
1319 | // effective_recurrent_to_cell_scale_a |
1320 | // effective_recurrent_to_cell_scale_b |
1321 | // effective_recurrent_to_output_scale_a |
1322 | // effective_recurrent_to_output_scale_b |
1323 | // effective_proj_scale_a - optional |
1324 | // effective_proj_scale_b - optional |
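//
// Each effective scale is passed as a fixed-point pair: scale_a is the int32
// multiplier and scale_b the accompanying shift (typically produced by
// QuantizeMultiplier() during Prepare()).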
1325 | // |
1326 | // Gate biases of size 'n_cell': |
1327 | // input_gate_bias_ptr - optional |
1328 | // forget_gate_bias_ptr |
1329 | // cell_gate_bias_ptr |
1330 | // output_gate_bias_ptr |
1331 | // |
1332 | // Layer norm coefficients of size 'n_cell', representing diagonal matrices. |
1333 | // layer_norm_input_weight_ptr - optional |
1334 | // layer_norm_forget_weight_ptr - optional |
1335 | // layer_norm_cell_weight_ptr - optional |
1336 | // layer_norm_output_weight_ptr - optional |
1337 | // |
// Layer norm scales (scalars) for each of the layer norm coefficients above.
1339 | // layer_norm_input_scale_a - optional |
1340 | // layer_norm_input_scale_b - optional |
1341 | // layer_norm_forget_scale_a - optional |
1342 | // layer_norm_forget_scale_b - optional |
1343 | // layer_norm_cell_scale_a - optional |
1344 | // layer_norm_cell_scale_b - optional |
1345 | // layer_norm_output_scale_a - optional |
1346 | // layer_norm_output_scale_b - optional |
1347 | // |
1348 | // Scalar values: |
1349 | // quantized_cell_clip: quantized clip value for cell. |
1350 | // quantized_proj_clip: quantized clip value for projection. |
1351 | // cell_state_scale: the power of two scale for cell state. |
1352 | // |
1353 | // Zero points: |
1354 | // output_state_zp: zero point of output state |
1355 | // hidden_zp: zero point for hidden state. |
1356 | // |
1357 | // Temporary pre-allocated storage for the calculation. Each is of size n_cell * |
1358 | // n_batch. |
1359 | // scratch0 |
1360 | // scratch1 |
1361 | // scratch2 |
1362 | // scratch3 |
1363 | // scratch4 |
1364 | // scratch5: this scratch buffer is created purely for optimizing the |
1365 | // MatrixBatchVectorMultiplyAccumulate. |
1366 | // |
1367 | // Outputs: |
1368 | // output_state_ptr - size 'n_batch * n_output' |
1369 | // cell_state_ptr - size 'n_batch * n_cell' |
1370 | // output_ptr - size 'n_batch * n_output' |
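//
// In float terms, each gate below roughly computes
//   gate = activation(input_to_gate * input + recurrent_to_gate * output_state
//                     [+ cell_to_gate .* cell_state] [+ layer norm] + gate_bias)
// with the matmuls carried out in integer arithmetic and rescaled through the
// effective scale pairs above; this is a sketch of the math, not the exact
// fixed-point pipeline.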
1371 | // TODO(b/159947023): scratch0 is not used if (!cifg). Don't allocate then. |
1372 | inline void LstmStepInteger8x8_16( |
1373 | const int8_t* input_ptr, const int8_t* input_to_input_weight_ptr, |
1374 | int32_t effective_input_to_input_scale_a, |
1375 | int32_t effective_input_to_input_scale_b, |
1376 | const int8_t* input_to_forget_weight_ptr, |
1377 | int32_t effective_input_to_forget_scale_a, |
1378 | int32_t effective_input_to_forget_scale_b, |
1379 | const int8_t* input_to_cell_weight_ptr, |
1380 | int32_t effective_input_to_cell_scale_a, |
1381 | int32_t effective_input_to_cell_scale_b, |
1382 | const int8_t* input_to_output_weight_ptr, |
1383 | int32_t effective_input_to_output_scale_a, |
1384 | int32_t effective_input_to_output_scale_b, |
1385 | const int8_t* recurrent_to_input_weight_ptr, |
1386 | int32_t effective_recurrent_to_input_scale_a, |
1387 | int32_t effective_recurrent_to_input_scale_b, |
1388 | const int8_t* recurrent_to_forget_weight_ptr, |
1389 | int32_t effective_recurrent_to_forget_scale_a, |
1390 | int32_t effective_recurrent_to_forget_scale_b, |
1391 | const int8_t* recurrent_to_cell_weight_ptr, |
1392 | int32_t effective_recurrent_to_cell_scale_a, |
1393 | int32_t effective_recurrent_to_cell_scale_b, |
1394 | const int8_t* recurrent_to_output_weight_ptr, |
1395 | int32_t effective_recurrent_to_output_scale_a, |
1396 | int32_t effective_recurrent_to_output_scale_b, |
1397 | const int16_t* cell_to_input_weight_ptr, |
1398 | int32_t effective_cell_to_input_scale_a, |
1399 | int32_t effective_cell_to_input_scale_b, |
1400 | const int16_t* cell_to_forget_weight_ptr, |
1401 | int32_t effective_cell_to_forget_scale_a, |
1402 | int32_t effective_cell_to_forget_scale_b, |
1403 | const int16_t* cell_to_output_weight_ptr, |
1404 | int32_t effective_cell_to_output_scale_a, |
1405 | int32_t effective_cell_to_output_scale_b, |
1406 | const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a, |
1407 | int32_t effective_proj_scale_b, int32_t hidden_zp, |
1408 | int32_t effective_hidden_scale_a, int32_t effective_hidden_scale_b, |
1409 | const int16_t* layer_norm_input_weight_ptr, |
1410 | int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b, |
1411 | const int16_t* layer_norm_forget_weight_ptr, |
1412 | int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b, |
1413 | const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a, |
1414 | int32_t layer_norm_cell_scale_b, |
1415 | const int16_t* layer_norm_output_weight_ptr, |
1416 | int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b, |
1417 | const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr, |
1418 | const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr, |
1419 | int16_t quantized_cell_clip, int8_t quantized_proj_clip, |
1420 | int32_t cell_state_scale, int32_t input_variance_guard, |
1421 | int32_t forget_variance_guard, int32_t cell_variance_guard, |
1422 | int32_t output_variance_guard, |
1423 | const int32_t* input_to_forget_effective_bias, |
1424 | const int32_t* recurrent_to_forget_effective_bias, |
1425 | const int32_t* input_to_cell_effective_bias, |
1426 | const int32_t* recurrent_to_cell_effective_bias, |
1427 | const int32_t* input_to_output_effective_bias, |
1428 | const int32_t* recurrent_to_output_effective_bias, |
1429 | const int32_t* input_to_input_effective_bias, |
1430 | const int32_t* recurrent_to_input_effective_bias, |
1431 | const int32_t* projection_effective_bias, int n_batch, int n_cell, |
1432 | int n_input, int n_output, int8_t* output_state_ptr, |
1433 | int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr, |
1434 | int16_t* scratch0, int16_t* scratch1, int16_t* scratch2, int16_t* scratch3, |
1435 | int8_t* scratch4, int32_t* scratch5, CpuBackendContext* context) { |
  ruy::profiler::ScopeLabel label("LstmStepInteger8x8_16");
1437 | // Make named scratch buffers for the different gates. |
1438 | int16_t* input_gate_scratch = scratch0; |
1439 | int16_t* forget_gate_scratch = scratch1; |
1440 | int16_t* cell_gate_scratch = scratch2; |
1441 | int16_t* output_gate_scratch = scratch3; |
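  // Each scratch slab holds one gate's activations in 16-bit fixed point,
  // n_batch * n_cell entries per gate.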
1442 | |
1443 | // Since we have already checked that weights are all there or none, we |
  // can check the existence of only one to get the condition.
1445 | const bool use_cifg = (input_to_input_weight_ptr == nullptr); |
1446 | |
1447 | // Check for nullptrs. |
1448 | TFLITE_DCHECK(input_to_forget_effective_bias); |
1449 | TFLITE_DCHECK(recurrent_to_forget_effective_bias); |
1450 | TFLITE_DCHECK(input_to_cell_effective_bias); |
1451 | TFLITE_DCHECK(recurrent_to_cell_effective_bias); |
1452 | TFLITE_DCHECK(input_to_output_effective_bias); |
1453 | TFLITE_DCHECK(recurrent_to_output_effective_bias); |
1454 | if (!use_cifg) { |
1455 | TFLITE_DCHECK(input_to_input_effective_bias); |
1456 | TFLITE_DCHECK(recurrent_to_input_effective_bias); |
1457 | } |
1458 | const bool use_projection = (projection_weight_ptr != nullptr); |
1459 | if (use_projection) { |
1460 | TFLITE_DCHECK(projection_effective_bias); |
1461 | } |
1462 | if (!use_cifg) { |
1463 | // Calculate the input gate. (If not CIFG.) |
1464 | CalculateLstmGateInteger8x8_16( |
1465 | input_ptr, input_to_input_weight_ptr, input_to_input_effective_bias, |
1466 | effective_input_to_input_scale_a, effective_input_to_input_scale_b, |
1467 | output_state_ptr, recurrent_to_input_weight_ptr, |
1468 | recurrent_to_input_effective_bias, effective_recurrent_to_input_scale_a, |
1469 | effective_recurrent_to_input_scale_b, cell_state_ptr, |
1470 | cell_to_input_weight_ptr, effective_cell_to_input_scale_a, |
1471 | effective_cell_to_input_scale_b, layer_norm_input_weight_ptr, |
1472 | input_gate_bias_ptr, layer_norm_input_scale_a, layer_norm_input_scale_b, |
1473 | input_variance_guard, n_batch, n_input, n_output, n_cell, |
1474 | kTfLiteActSigmoid, input_gate_scratch, context, scratch5); |
1475 | } |
1476 | // Calculate the forget gate. |
1477 | CalculateLstmGateInteger8x8_16( |
1478 | input_ptr, input_to_forget_weight_ptr, input_to_forget_effective_bias, |
1479 | effective_input_to_forget_scale_a, effective_input_to_forget_scale_b, |
1480 | output_state_ptr, recurrent_to_forget_weight_ptr, |
1481 | recurrent_to_forget_effective_bias, effective_recurrent_to_forget_scale_a, |
1482 | effective_recurrent_to_forget_scale_b, cell_state_ptr, |
1483 | cell_to_forget_weight_ptr, effective_cell_to_forget_scale_a, |
1484 | effective_cell_to_forget_scale_b, layer_norm_forget_weight_ptr, |
1485 | forget_gate_bias_ptr, layer_norm_forget_scale_a, |
1486 | layer_norm_forget_scale_b, forget_variance_guard, n_batch, n_input, |
1487 | n_output, n_cell, kTfLiteActSigmoid, forget_gate_scratch, context, |
1488 | scratch5); |
1489 | // Calculate the cell update gate. |
1490 | CalculateLstmGateInteger8x8_16( |
1491 | input_ptr, input_to_cell_weight_ptr, input_to_cell_effective_bias, |
1492 | effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, |
1493 | output_state_ptr, recurrent_to_cell_weight_ptr, |
1494 | recurrent_to_cell_effective_bias, effective_recurrent_to_cell_scale_a, |
1495 | effective_recurrent_to_cell_scale_b, cell_state_ptr, |
1496 | /*cell_to_gate_weights=*/nullptr, /*cell_to_gate_scale_a=*/0, |
1497 | /*cell_to_gate_scale_b=*/0, layer_norm_cell_weight_ptr, |
1498 | cell_gate_bias_ptr, layer_norm_cell_scale_a, layer_norm_cell_scale_b, |
1499 | cell_variance_guard, n_batch, n_input, n_output, n_cell, kTfLiteActTanh, |
1500 | cell_gate_scratch, context, scratch5); |
1501 | // Update the cell state. |
1502 | UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, cell_state_scale, |
1503 | input_gate_scratch, forget_gate_scratch, |
1504 | cell_gate_scratch, use_cifg, quantized_cell_clip); |
1505 | // Calculate the output gate. |
1506 | CalculateLstmGateInteger8x8_16( |
1507 | input_ptr, input_to_output_weight_ptr, input_to_output_effective_bias, |
1508 | effective_input_to_output_scale_a, effective_input_to_output_scale_b, |
1509 | output_state_ptr, recurrent_to_output_weight_ptr, |
1510 | recurrent_to_output_effective_bias, effective_recurrent_to_output_scale_a, |
1511 | effective_recurrent_to_output_scale_b, cell_state_ptr, |
1512 | cell_to_output_weight_ptr, effective_cell_to_output_scale_a, |
1513 | effective_cell_to_output_scale_b, layer_norm_output_weight_ptr, |
1514 | output_gate_bias_ptr, layer_norm_output_scale_a, |
1515 | layer_norm_output_scale_b, output_variance_guard, n_batch, n_input, |
1516 | n_output, n_cell, kTfLiteActSigmoid, output_gate_scratch, context, |
1517 | scratch5); |
1518 | // Update the output state. |
1519 | CalculateLstmOutputInteger8x8_16( |
1520 | n_batch, n_cell, n_output, cell_state_ptr, cell_state_scale, |
1521 | output_gate_scratch, effective_hidden_scale_a, effective_hidden_scale_b, |
1522 | hidden_zp, projection_weight_ptr, effective_proj_scale_a, |
1523 | effective_proj_scale_b, projection_effective_bias, output_state_zp, |
1524 | quantized_proj_clip, output_state_ptr, context, scratch0, scratch4, |
1525 | scratch5); |
1526 | // Copy output state to the output. Note that unlike float or hybrid, output |
1527 | // is always contiguous. |
1528 | std::copy_n(output_state_ptr, n_batch * n_output, output_ptr); |
1529 | } |
1530 | |
1531 | // Fully quantized lstm kernel for 8 bit gate matmul output. |
1532 | // |
1533 | // Input tensor of size n_batch * n_input: |
1534 | // input_ptr |
1535 | // |
1536 | // LSTM weights: |
// Quantized input weights of size 'n_cell * n_input':
//   input_to_input_weight_ptr - optional
//   input_to_forget_weight_ptr
//   input_to_cell_weight_ptr
//   input_to_output_weight_ptr
1542 | // |
// Quantized recurrent weights of size 'n_cell * n_output':
//   recurrent_to_input_weight_ptr - optional
//   recurrent_to_forget_weight_ptr
//   recurrent_to_cell_weight_ptr
//   recurrent_to_output_weight_ptr
1548 | // |
// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
//   cell_to_input_weights - optional
//   cell_to_forget_weights - optional
//   cell_to_output_weights - optional
1553 | // |
1554 | // Quantized projection weights of size 'n_output * n_cell' |
1555 | // projection_weight_ptr - optional |
1556 | // |
1557 | // Weight scales (scalars) for each of the weights above. |
1558 | // effective_input_to_input_scale_a - optional |
1559 | // effective_input_to_input_scale_b - optional |
1560 | // effective_input_to_forget_scale_a |
1561 | // effective_input_to_forget_scale_b |
1562 | // effective_input_to_cell_scale_a |
1563 | // effective_input_to_cell_scale_b |
1564 | // effective_input_to_output_scale_a |
1565 | // effective_input_to_output_scale_b |
1566 | // effective_recurrent_to_input_scale_a - optional |
1567 | // effective_recurrent_to_input_scale_b - optional |
1568 | // effective_recurrent_to_forget_scale_a |
1569 | // effective_recurrent_to_forget_scale_b |
1570 | // effective_recurrent_to_cell_scale_a |
1571 | // effective_recurrent_to_cell_scale_b |
1572 | // effective_recurrent_to_output_scale_a |
1573 | // effective_recurrent_to_output_scale_b |
1574 | // effective_proj_scale_a - optional |
1575 | // effective_proj_scale_b - optional |
1576 | // |
1577 | // Gate biases of size 'n_cell': |
1578 | // input_gate_bias_ptr - optional |
1579 | // forget_gate_bias_ptr |
1580 | // cell_gate_bias_ptr |
1581 | // output_gate_bias_ptr |
1582 | // |
1583 | // Layer norm coefficients of size 'n_cell', representing diagonal matrices. |
1584 | // layer_norm_input_weight_ptr - optional |
1585 | // layer_norm_forget_weight_ptr - optional |
1586 | // layer_norm_cell_weight_ptr - optional |
1587 | // layer_norm_output_weight_ptr - optional |
1588 | // |
// Layer norm scales (scalars) for each of the layer norm coefficients above.
1590 | // layer_norm_input_scale_a - optional |
1591 | // layer_norm_input_scale_b - optional |
1592 | // layer_norm_forget_scale_a - optional |
1593 | // layer_norm_forget_scale_b - optional |
1594 | // layer_norm_cell_scale_a - optional |
1595 | // layer_norm_cell_scale_b - optional |
1596 | // layer_norm_output_scale_a - optional |
1597 | // layer_norm_output_scale_b - optional |
1598 | // |
1599 | // Scalar values: |
1600 | // quantized_cell_clip: quantized clip value for cell. |
1601 | // quantized_proj_clip: quantized clip value for projection. |
1602 | // cell_state_scale: the power of two scale for cell state. |
1603 | // |
1604 | // Zero points: |
1605 | // input_zp: zero point for input tensor. |
1606 | // output_state_zp: zero point of output state. |
1607 | // hidden_zp: zero point for hidden state. |
1608 | // |
1609 | // Temporary pre-allocated storage for the calculation. Each is of size n_cell * |
1610 | // n_batch. |
1611 | // scratch0 |
1612 | // scratch1 |
1613 | // scratch2 |
1614 | // scratch3 |
1615 | // scratch4 |
1616 | // scratch5 |
1617 | // scratch6 |
1618 | // scratch7 |
1619 | // |
1620 | // Outputs: |
1621 | // output_state_ptr - size 'n_batch * n_output' |
1622 | // cell_state_ptr - size 'n_batch * n_cell' |
1623 | // output_ptr - size 'n_batch * n_output' |
1624 | // |
// Can move the zero point calculation into Prepare() for better performance.
1626 | // TODO(b/159947023): scratch5 is unused, remove. |
1627 | inline void LstmStepInteger8x8_8( |
1628 | const int8_t* input_ptr, int32_t input_zp, |
1629 | const int8_t* input_to_input_weight_ptr, |
1630 | int32_t effective_input_to_input_scale_a, |
1631 | int32_t effective_input_to_input_scale_b, |
1632 | const int8_t* input_to_forget_weight_ptr, |
1633 | int32_t effective_input_to_forget_scale_a, |
1634 | int32_t effective_input_to_forget_scale_b, |
1635 | const int8_t* input_to_cell_weight_ptr, |
1636 | int32_t effective_input_to_cell_scale_a, |
1637 | int32_t effective_input_to_cell_scale_b, |
1638 | const int8_t* input_to_output_weight_ptr, |
1639 | int32_t effective_input_to_output_scale_a, |
1640 | int32_t effective_input_to_output_scale_b, |
1641 | const int8_t* recurrent_to_input_weight_ptr, |
1642 | int32_t effective_recurrent_to_input_scale_a, |
1643 | int32_t effective_recurrent_to_input_scale_b, |
1644 | const int8_t* recurrent_to_forget_weight_ptr, |
1645 | int32_t effective_recurrent_to_forget_scale_a, |
1646 | int32_t effective_recurrent_to_forget_scale_b, |
1647 | const int8_t* recurrent_to_cell_weight_ptr, |
1648 | int32_t effective_recurrent_to_cell_scale_a, |
1649 | int32_t effective_recurrent_to_cell_scale_b, |
1650 | const int8_t* recurrent_to_output_weight_ptr, |
1651 | int32_t effective_recurrent_to_output_scale_a, |
1652 | int32_t effective_recurrent_to_output_scale_b, |
1653 | const int8_t* cell_to_input_weight_ptr, |
1654 | int32_t effective_cell_to_input_scale_a, |
1655 | int32_t effective_cell_to_input_scale_b, |
1656 | const int8_t* cell_to_forget_weight_ptr, |
1657 | int32_t effective_cell_to_forget_scale_a, |
1658 | int32_t effective_cell_to_forget_scale_b, |
1659 | const int8_t* cell_to_output_weight_ptr, |
1660 | int32_t effective_cell_to_output_scale_a, |
1661 | int32_t effective_cell_to_output_scale_b, |
1662 | const int8_t* projection_weight_ptr, int32_t effective_proj_scale_a, |
1663 | int32_t effective_proj_scale_b, const int16_t* layer_norm_input_weight_ptr, |
1664 | int32_t layer_norm_input_scale_a, int32_t layer_norm_input_scale_b, |
1665 | const int16_t* layer_norm_forget_weight_ptr, |
1666 | int32_t layer_norm_forget_scale_a, int32_t layer_norm_forget_scale_b, |
1667 | const int16_t* layer_norm_cell_weight_ptr, int32_t layer_norm_cell_scale_a, |
1668 | int32_t layer_norm_cell_scale_b, |
1669 | const int16_t* layer_norm_output_weight_ptr, |
1670 | int32_t layer_norm_output_scale_a, int32_t layer_norm_output_scale_b, |
1671 | const int32_t* input_gate_bias_ptr, const int32_t* forget_gate_bias_ptr, |
1672 | const int32_t* cell_gate_bias_ptr, const int32_t* output_gate_bias_ptr, |
1673 | const int32_t* projection_bias_ptr, const TfLiteLSTMParams* params, |
1674 | const int32_t* intermediate_scale_a, const int32_t* intermediate_scale_b, |
1675 | const int32_t* intermediate_zp, int16_t quantized_cell_clip, |
1676 | int8_t quantized_proj_clip, int n_batch, int n_cell, int n_input, |
1677 | int n_output, int output_batch_leading_dim, int8_t* output_state_ptr, |
1678 | int32_t output_state_zp, int16_t* cell_state_ptr, int8_t* output_ptr, |
1679 | int8_t* scratch0, int8_t* scratch1, int16_t* scratch2, int16_t* scratch3, |
1680 | int16_t* scratch4, int16_t* scratch5, int16_t* scratch6, |
1681 | int16_t* scratch7) { |
1682 | // TODO(b/159066113): scratch5 is unused, remove. |
1683 | |
  ruy::profiler::ScopeLabel label("LstmStepInteger8x8_8");
1685 | // Make named scratch buffers for the different gates. |
1686 | int16_t* forget_gate_scratch = scratch2; |
1687 | int16_t* cell_gate_scratch = scratch3; |
1688 | int16_t* output_gate_scratch = scratch4; |
  // Only CIFG is supported here, so there is no input gate scratch.
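  // Each gate below consumes two (scale_a, scale_b, zp) triples from the
  // intermediate arrays: one for the input matmul and one for the recurrent
  // matmul. The indices are fixed per gate (forget: scales 2/3, cell: 4/5,
  // output: 6/7).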
1690 | |
1691 | // Calculate the forget gate. |
1692 | CalculateLstmGateInteger8x8_8( |
1693 | input_ptr, input_zp, input_to_forget_weight_ptr, |
1694 | effective_input_to_forget_scale_a, effective_input_to_forget_scale_b, |
1695 | intermediate_scale_a[2], intermediate_scale_b[2], intermediate_zp[4], |
1696 | output_state_ptr, output_state_zp, recurrent_to_forget_weight_ptr, |
1697 | effective_recurrent_to_forget_scale_a, |
1698 | effective_recurrent_to_forget_scale_b, intermediate_scale_a[3], |
1699 | intermediate_scale_b[3], intermediate_zp[5], layer_norm_forget_weight_ptr, |
1700 | layer_norm_forget_scale_a, layer_norm_forget_scale_b, |
1701 | forget_gate_bias_ptr, n_batch, n_input, n_output, n_cell, |
1702 | kTfLiteActSigmoid, forget_gate_scratch, scratch0, scratch1); |
1703 | // Calculate the cell update gate. |
1704 | CalculateLstmGateInteger8x8_8( |
1705 | input_ptr, input_zp, input_to_cell_weight_ptr, |
1706 | effective_input_to_cell_scale_a, effective_input_to_cell_scale_b, |
1707 | intermediate_scale_a[4], intermediate_scale_b[4], intermediate_zp[7], |
1708 | output_state_ptr, output_state_zp, recurrent_to_cell_weight_ptr, |
1709 | effective_recurrent_to_cell_scale_a, effective_recurrent_to_cell_scale_b, |
1710 | intermediate_scale_a[5], intermediate_scale_b[5], intermediate_zp[8], |
1711 | layer_norm_cell_weight_ptr, layer_norm_cell_scale_a, |
1712 | layer_norm_cell_scale_b, cell_gate_bias_ptr, n_batch, n_input, n_output, |
1713 | n_cell, kTfLiteActTanh, cell_gate_scratch, scratch0, scratch1); |
1714 | // Update the cell state. |
1715 | UpdateLstmCellInteger(n_batch, n_cell, cell_state_ptr, |
1716 | /*cell_state_scale=*/-15, /*input_gate=*/nullptr, |
1717 | forget_gate_scratch, cell_gate_scratch, |
1718 | /*use_cifg=*/true, quantized_cell_clip); |
1719 | // Calculate the output gate. |
1720 | CalculateLstmGateInteger8x8_8( |
1721 | input_ptr, input_zp, input_to_output_weight_ptr, |
1722 | effective_input_to_output_scale_a, effective_input_to_output_scale_b, |
1723 | intermediate_scale_a[6], intermediate_scale_b[6], intermediate_zp[10], |
1724 | output_state_ptr, output_state_zp, recurrent_to_output_weight_ptr, |
1725 | effective_recurrent_to_output_scale_a, |
      effective_recurrent_to_output_scale_b, intermediate_scale_a[7],
      intermediate_scale_b[7], intermediate_zp[11], layer_norm_output_weight_ptr,
1728 | layer_norm_output_scale_a, layer_norm_output_scale_b, |
1729 | output_gate_bias_ptr, n_batch, n_input, n_output, n_cell, |
1730 | kTfLiteActSigmoid, output_gate_scratch, scratch0, scratch1); |
1731 | // Update the output state. |
1732 | CalculateLstmOutputInteger8x8_8( |
1733 | n_batch, n_cell, n_output, cell_state_ptr, output_gate_scratch, |
1734 | projection_weight_ptr, effective_proj_scale_a, effective_proj_scale_b, |
1735 | projection_bias_ptr, output_state_zp, quantized_proj_clip, |
1736 | output_state_ptr, scratch2); |
1737 | // Copy output state to the output. Note that unlike float or hybrid, output |
  // is always contiguous.
1739 | std::copy_n(output_state_ptr, n_batch * n_output, output_ptr); |
1740 | } |
1741 | |
1742 | } // namespace |
1743 | |
1744 | // LINT.IfChange |
1745 | TfLiteStatus EvalFloat( |
1746 | const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, |
1747 | const TfLiteTensor* input_to_forget_weights, |
1748 | const TfLiteTensor* input_to_cell_weights, |
1749 | const TfLiteTensor* input_to_output_weights, |
1750 | const TfLiteTensor* recurrent_to_input_weights, |
1751 | const TfLiteTensor* recurrent_to_forget_weights, |
1752 | const TfLiteTensor* recurrent_to_cell_weights, |
1753 | const TfLiteTensor* recurrent_to_output_weights, |
1754 | const TfLiteTensor* cell_to_input_weights, |
1755 | const TfLiteTensor* cell_to_forget_weights, |
1756 | const TfLiteTensor* cell_to_output_weights, |
1757 | const TfLiteTensor* input_layer_norm_coefficients, |
1758 | const TfLiteTensor* forget_layer_norm_coefficients, |
1759 | const TfLiteTensor* cell_layer_norm_coefficients, |
1760 | const TfLiteTensor* output_layer_norm_coefficients, |
1761 | const TfLiteTensor* aux_input, |
1762 | const TfLiteTensor* aux_input_to_input_weights, |
1763 | const TfLiteTensor* aux_input_to_forget_weights, |
1764 | const TfLiteTensor* aux_input_to_cell_weights, |
1765 | const TfLiteTensor* aux_input_to_output_weights, |
1766 | const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, |
1767 | const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias, |
1768 | const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, |
1769 | const TfLiteLSTMParams* params, bool forward_sequence, bool time_major, |
1770 | int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* output_state, |
1771 | TfLiteTensor* cell_state, TfLiteTensor* output, |
1772 | CpuBackendContext* context) { |
1773 | TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3); |
1774 | int max_time, n_batch; |
1775 | if (input->dims->size == 3) { |
1776 | max_time = (time_major) ? input->dims->data[0] : input->dims->data[1]; |
1777 | n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0]; |
1778 | } else { |
1779 | max_time = 1; |
1780 | n_batch = input->dims->data[0]; |
1781 | } |
1782 | const int n_input = input->dims->data[input->dims->size - 1]; |
1783 | const int aux_input_size = |
1784 | (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0; |
1785 | |
1786 | // n_cell and n_output will be the same size when there is no projection. |
1787 | const int n_cell = input_to_output_weights->dims->data[0]; |
1788 | const int n_output = recurrent_to_output_weights->dims->data[1]; |
1789 | |
1790 | // Since we have already checked that weights are all there or none, we can |
  // check the existence of only one to get the condition.
1792 | const bool use_cifg = (input_to_input_weights == nullptr); |
1793 | |
  // Index the scratch buffer pointers into the global scratch buffer.
1795 | float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer); |
1796 | float* input_gate_scratch = nullptr; |
1797 | float* cell_gate_scratch = nullptr; |
1798 | float* forget_gate_scratch = nullptr; |
1799 | float* output_gate_scratch = nullptr; |
1800 | float* accumulation_scratch_buffer = nullptr; |
1801 | if (use_cifg) { |
1802 | cell_gate_scratch = scratch_buffer_ptr; |
1803 | forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch; |
1804 | output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch; |
1805 | accumulation_scratch_buffer = scratch_buffer_ptr + 3 * n_cell * n_batch; |
1806 | } else { |
1807 | input_gate_scratch = scratch_buffer_ptr; |
1808 | cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch; |
1809 | forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch; |
1810 | output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch; |
1811 | accumulation_scratch_buffer = scratch_buffer_ptr + 4 * n_cell * n_batch; |
1812 | } |
1813 | |
1814 | const int output_batch_leading_dim = |
1815 | output->dims->data[output->dims->size - 1]; |
1816 | if (time_major) { |
1817 | // Loop through the sequence. |
1818 | const int input_step = n_batch * n_input; |
1819 | const int output_step = n_batch * output_batch_leading_dim; |
1820 | for (int t = 0; t < max_time; t++) { |
1821 | // If this is the forward_sequence, step forward, otherwise step |
1822 | // backwards. |
1823 | const int t_rel = forward_sequence ? t : max_time - t - 1; |
1824 | const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step; |
1825 | const float* aux_input_ptr = nullptr; |
1826 | if (aux_input) { |
1827 | aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step; |
1828 | } |
1829 | float* output_ptr = |
1830 | GetTensorData<float>(output) + t_rel * output_step + output_offset; |
1831 | |
1832 | LstmStepFloat( |
1833 | input_ptr, GetTensorData<float>(input_to_input_weights), |
1834 | GetTensorData<float>(input_to_forget_weights), |
1835 | GetTensorData<float>(input_to_cell_weights), |
1836 | GetTensorData<float>(input_to_output_weights), aux_input_ptr, |
1837 | GetTensorData<float>(aux_input_to_input_weights), |
1838 | GetTensorData<float>(aux_input_to_forget_weights), |
1839 | GetTensorData<float>(aux_input_to_cell_weights), |
1840 | GetTensorData<float>(aux_input_to_output_weights), |
1841 | GetTensorData<float>(recurrent_to_input_weights), |
1842 | GetTensorData<float>(recurrent_to_forget_weights), |
1843 | GetTensorData<float>(recurrent_to_cell_weights), |
1844 | GetTensorData<float>(recurrent_to_output_weights), |
1845 | GetTensorData<float>(cell_to_input_weights), |
1846 | GetTensorData<float>(cell_to_forget_weights), |
1847 | GetTensorData<float>(cell_to_output_weights), |
1848 | GetTensorData<float>(input_layer_norm_coefficients), |
1849 | GetTensorData<float>(forget_layer_norm_coefficients), |
1850 | GetTensorData<float>(cell_layer_norm_coefficients), |
1851 | GetTensorData<float>(output_layer_norm_coefficients), |
1852 | GetTensorData<float>(input_gate_bias), |
1853 | GetTensorData<float>(forget_gate_bias), |
1854 | GetTensorData<float>(cell_gate_bias), |
1855 | GetTensorData<float>(output_gate_bias), |
1856 | GetTensorData<float>(projection_weights), |
1857 | GetTensorData<float>(projection_bias), params, n_batch, n_cell, |
1858 | n_input, aux_input_size, n_output, output_batch_leading_dim, |
1859 | GetTensorData<float>(output_state), GetTensorData<float>(cell_state), |
1860 | input_gate_scratch, forget_gate_scratch, cell_gate_scratch, |
1861 | output_gate_scratch, accumulation_scratch_buffer, output_ptr, |
1862 | context); |
1863 | } |
1864 | } else { |
1865 | for (int b = 0; b < n_batch; b++) { |
1866 | const int input_step = n_input; |
1867 | const int output_step = output_batch_leading_dim; |
1868 | for (int t = 0; t < max_time; t++) { |
1869 | // If this is the forward_sequence, step forward, otherwise step |
1870 | // backwards. |
1871 | const int t_rel = forward_sequence ? t : max_time - t - 1; |
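        // In batch-major layout the tensors are [n_batch, max_time, ...], so
        // element (b, t_rel) starts at offset (b * max_time + t_rel) * step.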
1872 | const int time_offset = b * max_time + t_rel; |
1873 | const float* input_ptr = |
1874 | GetTensorData<float>(input) + time_offset * input_step; |
1875 | const float* aux_input_ptr = nullptr; |
1876 | if (aux_input) { |
1877 | aux_input_ptr = |
1878 | GetTensorData<float>(aux_input) + time_offset * input_step; |
1879 | } |
1880 | float* output_ptr = GetTensorData<float>(output) + |
1881 | time_offset * output_step + output_offset; |
1882 | |
1883 | // Offset the {output,cell}_state pointers to the right batch. |
1884 | float* output_state_ptr = |
1885 | GetTensorData<float>(output_state) + b * output_batch_leading_dim; |
1886 | float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell; |
1887 | // Offset the scratch pointers to the right batch. |
1888 | float* input_gate_scratch_ptr = |
1889 | input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr; |
1890 | float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell; |
1891 | float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell; |
1892 | float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell; |
1893 | |
1894 | LstmStepFloat( |
1895 | input_ptr, GetTensorData<float>(input_to_input_weights), |
1896 | GetTensorData<float>(input_to_forget_weights), |
1897 | GetTensorData<float>(input_to_cell_weights), |
1898 | GetTensorData<float>(input_to_output_weights), aux_input_ptr, |
1899 | GetTensorData<float>(aux_input_to_input_weights), |
1900 | GetTensorData<float>(aux_input_to_forget_weights), |
1901 | GetTensorData<float>(aux_input_to_cell_weights), |
1902 | GetTensorData<float>(aux_input_to_output_weights), |
1903 | GetTensorData<float>(recurrent_to_input_weights), |
1904 | GetTensorData<float>(recurrent_to_forget_weights), |
1905 | GetTensorData<float>(recurrent_to_cell_weights), |
1906 | GetTensorData<float>(recurrent_to_output_weights), |
1907 | GetTensorData<float>(cell_to_input_weights), |
1908 | GetTensorData<float>(cell_to_forget_weights), |
1909 | GetTensorData<float>(cell_to_output_weights), |
1910 | GetTensorData<float>(input_layer_norm_coefficients), |
1911 | GetTensorData<float>(forget_layer_norm_coefficients), |
1912 | GetTensorData<float>(cell_layer_norm_coefficients), |
1913 | GetTensorData<float>(output_layer_norm_coefficients), |
1914 | GetTensorData<float>(input_gate_bias), |
1915 | GetTensorData<float>(forget_gate_bias), |
1916 | GetTensorData<float>(cell_gate_bias), |
1917 | GetTensorData<float>(output_gate_bias), |
1918 | GetTensorData<float>(projection_weights), |
1919 | GetTensorData<float>(projection_bias), params, /*n_batch=*/1, |
1920 | n_cell, n_input, aux_input_size, n_output, output_batch_leading_dim, |
1921 | output_state_ptr, cell_state_ptr, input_gate_scratch_ptr, |
1922 | forget_gate_scratch_ptr, cell_gate_scratch_ptr, |
1923 | output_gate_scratch_ptr, accumulation_scratch_buffer, output_ptr, |
1924 | context); |
1925 | } |
1926 | } |
1927 | } |
1928 | return kTfLiteOk; |
1929 | } |
1930 | // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc) |
1931 | |
1932 | TfLiteStatus EvalHybrid( |
1933 | const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, |
1934 | const TfLiteTensor* input_to_input_weights_ledger, |
1935 | const TfLiteTensor* input_to_forget_weights, |
1936 | const TfLiteTensor* input_to_forget_weights_ledger, |
1937 | const TfLiteTensor* input_to_cell_weights, |
1938 | const TfLiteTensor* input_to_cell_weights_ledger, |
1939 | const TfLiteTensor* input_to_output_weights, |
1940 | const TfLiteTensor* input_to_output_weights_ledger, |
1941 | const TfLiteTensor* recurrent_to_input_weights, |
1942 | const TfLiteTensor* recurrent_to_input_weights_ledger, |
1943 | const TfLiteTensor* recurrent_to_forget_weights, |
1944 | const TfLiteTensor* recurrent_to_forget_weights_ledger, |
1945 | const TfLiteTensor* recurrent_to_cell_weights, |
1946 | const TfLiteTensor* recurrent_to_cell_weights_ledger, |
1947 | const TfLiteTensor* recurrent_to_output_weights, |
1948 | const TfLiteTensor* recurrent_to_output_weights_ledger, |
1949 | const TfLiteTensor* cell_to_input_weights, |
1950 | const TfLiteTensor* cell_to_forget_weights, |
1951 | const TfLiteTensor* cell_to_output_weights, |
1952 | const TfLiteTensor* input_layer_norm_coefficients, |
1953 | const TfLiteTensor* forget_layer_norm_coefficients, |
1954 | const TfLiteTensor* cell_layer_norm_coefficients, |
1955 | const TfLiteTensor* output_layer_norm_coefficients, |
1956 | const TfLiteTensor* aux_input, |
1957 | const TfLiteTensor* aux_input_to_input_weights, |
1958 | const TfLiteTensor* aux_input_to_forget_weights, |
1959 | const TfLiteTensor* aux_input_to_cell_weights, |
1960 | const TfLiteTensor* aux_input_to_output_weights, |
1961 | const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, |
1962 | const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias, |
1963 | const TfLiteTensor* projection_weights, |
1964 | const TfLiteTensor* projection_weights_ledger, |
1965 | const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params, |
1966 | bool forward_sequence, bool time_major, int output_offset, |
1967 | TfLiteTensor* scratch_buffer, TfLiteTensor* input_sf, |
1968 | TfLiteTensor* aux_input_sf, TfLiteTensor* output_state_sf, |
1969 | TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, |
1970 | TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, |
1971 | TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, |
1972 | TfLiteTensor* output_state, TfLiteTensor* cell_state, |
1973 | TfLiteTensor* output_scratch_buffer, TfLiteTensor* output, |
1974 | TfLiteTensor* input_zp, TfLiteTensor* aux_input_zp, |
1975 | TfLiteTensor* output_state_zp, TfLiteTensor* row_sums, int row_sums_size, |
1976 | bool* compute_row_sums, CpuBackendContext* context) { |
1977 | TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3); |
1978 | const int n_input = input->dims->data[input->dims->size - 1]; |
1979 | int max_time, n_batch; |
1980 | if (input->dims->size == 2) { |
1981 | max_time = 1; |
1982 | n_batch = input->dims->data[0]; |
1983 | } else { |
1984 | max_time = (time_major) ? input->dims->data[0] : input->dims->data[1]; |
1985 | n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0]; |
1986 | } |
1987 | const int aux_input_size = |
1988 | (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0; |
1989 | // n_cell and n_output will be the same size when there is no projection. |
1990 | const int n_cell = input_to_output_weights->dims->data[0]; |
1991 | const int n_output = recurrent_to_output_weights->dims->data[1]; |
1992 | |
1993 | // Since we have already checked that weights are all there or none, we can |
1994 | // check the existence of only one to get the condition. |
1995 | const bool use_cifg = (input_to_input_weights == nullptr); |
1996 | |
1997 | float* scratch_buffer_ptr = GetTensorData<float>(scratch_buffer); |
1998 | float* input_gate_scratch = nullptr; |
1999 | float* cell_gate_scratch = nullptr; |
2000 | float* forget_gate_scratch = nullptr; |
2001 | float* output_gate_scratch = nullptr; |
2002 | if (use_cifg) { |
2003 | cell_gate_scratch = scratch_buffer_ptr; |
2004 | forget_gate_scratch = scratch_buffer_ptr + n_cell * n_batch; |
2005 | output_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch; |
2006 | } else { |
2007 | input_gate_scratch = scratch_buffer_ptr; |
2008 | cell_gate_scratch = scratch_buffer_ptr + n_cell * n_batch; |
2009 | forget_gate_scratch = scratch_buffer_ptr + 2 * n_cell * n_batch; |
2010 | output_gate_scratch = scratch_buffer_ptr + 3 * n_cell * n_batch; |
2011 | } |
2012 | |
2013 | const int output_batch_leading_dim = |
2014 | output->dims->data[output->dims->size - 1]; |
2015 | |
2016 | int32_t* input_zp_ptr = nullptr; |
2017 | int32_t* aux_input_zp_ptr = nullptr; |
2018 | int32_t* output_state_zp_ptr = nullptr; |
2019 | int32_t* row_sums_ptr = nullptr; |
2020 | if (params->asymmetric_quantize_inputs) { |
2021 | input_zp_ptr = GetTensorData<int32_t>(input_zp); |
2022 | aux_input_zp_ptr = GetTensorData<int32_t>(aux_input_zp); |
2023 | output_state_zp_ptr = GetTensorData<int32_t>(output_state_zp); |
2024 | row_sums_ptr = GetTensorData<int32_t>(row_sums); |
2025 | } |
2026 | |
2027 | if (time_major) { |
2028 | // Feed the sequence into the LSTM step-by-step. |
2029 | const int input_step = n_batch * n_input; |
2030 | const int output_step = n_batch * output_batch_leading_dim; |
2031 | for (int t = 0; t < max_time; t++) { |
2032 | // If this is the forward_sequence, step forward, otherwise step |
2033 | // backwards. |
2034 | const int t_rel = forward_sequence ? t : max_time - t - 1; |
2035 | const float* input_ptr = GetTensorData<float>(input) + t_rel * input_step; |
2036 | const float* aux_input_ptr = nullptr; |
2037 | if (aux_input) { |
2038 | aux_input_ptr = GetTensorData<float>(aux_input) + t_rel * input_step; |
2039 | } |
2040 | float* output_ptr = |
2041 | GetTensorData<float>(output) + t_rel * output_step + output_offset; |
2042 | LstmStepHybrid( |
2043 | input_ptr, GetTensorData<int8_t>(input_to_input_weights), |
2044 | GetTensorData<uint8_t>(input_to_input_weights_ledger), |
2045 | GetTensorScale(input_to_input_weights), |
2046 | GetTensorData<int8_t>(input_to_forget_weights), |
2047 | GetTensorData<uint8_t>(input_to_forget_weights_ledger), |
2048 | GetTensorScale(input_to_forget_weights), |
2049 | GetTensorData<int8_t>(input_to_cell_weights), |
2050 | GetTensorData<uint8_t>(input_to_cell_weights_ledger), |
2051 | GetTensorScale(input_to_cell_weights), |
2052 | GetTensorData<int8_t>(input_to_output_weights), |
2053 | GetTensorData<uint8_t>(input_to_output_weights_ledger), |
2054 | GetTensorScale(input_to_output_weights), aux_input_ptr, |
2055 | GetTensorData<int8_t>(aux_input_to_input_weights), |
2056 | GetTensorScale(aux_input_to_input_weights), |
2057 | GetTensorData<int8_t>(aux_input_to_forget_weights), |
2058 | GetTensorScale(aux_input_to_forget_weights), |
2059 | GetTensorData<int8_t>(aux_input_to_cell_weights), |
2060 | GetTensorScale(aux_input_to_cell_weights), |
2061 | GetTensorData<int8_t>(aux_input_to_output_weights), |
2062 | GetTensorScale(aux_input_to_output_weights), |
2063 | GetTensorData<int8_t>(recurrent_to_input_weights), |
2064 | GetTensorData<uint8_t>(recurrent_to_input_weights_ledger), |
2065 | GetTensorScale(recurrent_to_input_weights), |
2066 | GetTensorData<int8_t>(recurrent_to_forget_weights), |
2067 | GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger), |
2068 | GetTensorScale(recurrent_to_forget_weights), |
2069 | GetTensorData<int8_t>(recurrent_to_cell_weights), |
2070 | GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger), |
2071 | GetTensorScale(recurrent_to_cell_weights), |
2072 | GetTensorData<int8_t>(recurrent_to_output_weights), |
2073 | GetTensorData<uint8_t>(recurrent_to_output_weights_ledger), |
2074 | GetTensorScale(recurrent_to_output_weights), |
2075 | GetTensorData<int8_t>(cell_to_input_weights), |
2076 | GetTensorScale(cell_to_input_weights), |
2077 | GetTensorData<int8_t>(cell_to_forget_weights), |
2078 | GetTensorScale(cell_to_forget_weights), |
2079 | GetTensorData<int8_t>(cell_to_output_weights), |
2080 | GetTensorScale(cell_to_output_weights), |
2081 | GetTensorData<float>(input_layer_norm_coefficients), |
2082 | GetTensorData<float>(forget_layer_norm_coefficients), |
2083 | GetTensorData<float>(cell_layer_norm_coefficients), |
2084 | GetTensorData<float>(output_layer_norm_coefficients), |
2085 | GetTensorData<float>(input_gate_bias), |
2086 | GetTensorData<float>(forget_gate_bias), |
2087 | GetTensorData<float>(cell_gate_bias), |
2088 | GetTensorData<float>(output_gate_bias), |
2089 | GetTensorData<int8_t>(projection_weights), |
2090 | GetTensorData<uint8_t>(projection_weights_ledger), |
2091 | GetTensorScale(projection_weights), |
2092 | GetTensorData<float>(projection_bias), params, n_batch, n_cell, |
2093 | n_input, aux_input_size, n_output, output_batch_leading_dim, |
2094 | input_gate_scratch, forget_gate_scratch, cell_gate_scratch, |
2095 | output_gate_scratch, GetTensorData<float>(input_sf), |
2096 | GetTensorData<float>(aux_input_sf), |
2097 | GetTensorData<float>(output_state_sf), |
2098 | GetTensorData<float>(prod_scaling_factors), |
2099 | GetTensorData<float>(recovered_cell_weights), |
2100 | GetTensorData<int8_t>(input_quantized), |
2101 | GetTensorData<int8_t>(aux_input_quantized), |
2102 | GetTensorData<int8_t>(output_state_quantized), |
2103 | GetTensorData<int8_t>(cell_state_quantized), |
2104 | GetTensorData<float>(output_state), GetTensorData<float>(cell_state), |
2105 | GetTensorData<int32_t>(output_scratch_buffer), output_ptr, |
2106 | input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr, row_sums_ptr, |
2107 | row_sums_size, compute_row_sums, params->asymmetric_quantize_inputs, |
2108 | context); |
2109 | } |
2110 | } else { |
2111 | for (int b = 0; b < n_batch; b++) { |
2112 | const int input_step = n_input; |
2113 | const int output_step = output_batch_leading_dim; |
2114 | for (int t = 0; t < max_time; t++) { |
2115 | // If this is the forward_sequence, step forward, otherwise step |
2116 | // backwards. |
2117 | const int t_rel = forward_sequence ? t : max_time - t - 1; |
2118 | const int time_offset = b * max_time + t_rel; |
2119 | const float* input_ptr = |
2120 | GetTensorData<float>(input) + time_offset * input_step; |
2121 | const float* aux_input_ptr = nullptr; |
2122 | if (aux_input) { |
2123 | aux_input_ptr = |
2124 | GetTensorData<float>(aux_input) + time_offset * input_step; |
2125 | } |
2126 | float* output_ptr = GetTensorData<float>(output) + |
2127 | time_offset * output_step + output_offset; |
2128 | |
2129 | // Offset the {output,cell}_state pointers to the right batch. |
2130 | float* output_state_ptr = |
2131 | GetTensorData<float>(output_state) + b * output_batch_leading_dim; |
2132 | float* cell_state_ptr = GetTensorData<float>(cell_state) + b * n_cell; |
2133 | // Offset the scratch pointers to the right batch. |
2134 | float* input_gate_scratch_ptr = |
2135 | input_gate_scratch ? input_gate_scratch + b * n_cell : nullptr; |
2136 | float* forget_gate_scratch_ptr = forget_gate_scratch + b * n_cell; |
2137 | float* cell_gate_scratch_ptr = cell_gate_scratch + b * n_cell; |
2138 | float* output_gate_scratch_ptr = output_gate_scratch + b * n_cell; |
2139 | |
        LstmStepHybrid(
            input_ptr, GetTensorData<int8_t>(input_to_input_weights),
            GetTensorData<uint8_t>(input_to_input_weights_ledger),
            GetTensorScale(input_to_input_weights),
            GetTensorData<int8_t>(input_to_forget_weights),
            GetTensorData<uint8_t>(input_to_forget_weights_ledger),
            GetTensorScale(input_to_forget_weights),
            GetTensorData<int8_t>(input_to_cell_weights),
            GetTensorData<uint8_t>(input_to_cell_weights_ledger),
            GetTensorScale(input_to_cell_weights),
            GetTensorData<int8_t>(input_to_output_weights),
            GetTensorData<uint8_t>(input_to_output_weights_ledger),
            GetTensorScale(input_to_output_weights), aux_input_ptr,
            GetTensorData<int8_t>(aux_input_to_input_weights),
            GetTensorScale(aux_input_to_input_weights),
            GetTensorData<int8_t>(aux_input_to_forget_weights),
            GetTensorScale(aux_input_to_forget_weights),
            GetTensorData<int8_t>(aux_input_to_cell_weights),
            GetTensorScale(aux_input_to_cell_weights),
            GetTensorData<int8_t>(aux_input_to_output_weights),
            GetTensorScale(aux_input_to_output_weights),
            GetTensorData<int8_t>(recurrent_to_input_weights),
            GetTensorData<uint8_t>(recurrent_to_input_weights_ledger),
            GetTensorScale(recurrent_to_input_weights),
            GetTensorData<int8_t>(recurrent_to_forget_weights),
            GetTensorData<uint8_t>(recurrent_to_forget_weights_ledger),
            GetTensorScale(recurrent_to_forget_weights),
            GetTensorData<int8_t>(recurrent_to_cell_weights),
            GetTensorData<uint8_t>(recurrent_to_cell_weights_ledger),
            GetTensorScale(recurrent_to_cell_weights),
            GetTensorData<int8_t>(recurrent_to_output_weights),
            GetTensorData<uint8_t>(recurrent_to_output_weights_ledger),
            GetTensorScale(recurrent_to_output_weights),
            GetTensorData<int8_t>(cell_to_input_weights),
            GetTensorScale(cell_to_input_weights),
            GetTensorData<int8_t>(cell_to_forget_weights),
            GetTensorScale(cell_to_forget_weights),
            GetTensorData<int8_t>(cell_to_output_weights),
            GetTensorScale(cell_to_output_weights),
            GetTensorData<float>(input_layer_norm_coefficients),
            GetTensorData<float>(forget_layer_norm_coefficients),
            GetTensorData<float>(cell_layer_norm_coefficients),
            GetTensorData<float>(output_layer_norm_coefficients),
            GetTensorData<float>(input_gate_bias),
            GetTensorData<float>(forget_gate_bias),
            GetTensorData<float>(cell_gate_bias),
            GetTensorData<float>(output_gate_bias),
            GetTensorData<int8_t>(projection_weights),
            GetTensorData<uint8_t>(projection_weights_ledger),
            GetTensorScale(projection_weights),
            GetTensorData<float>(projection_bias), params,
            /*n_batch=*/1, n_cell, n_input, aux_input_size, n_output,
            output_batch_leading_dim, input_gate_scratch_ptr,
            forget_gate_scratch_ptr, cell_gate_scratch_ptr,
            output_gate_scratch_ptr, GetTensorData<float>(input_sf),
            GetTensorData<float>(aux_input_sf),
            GetTensorData<float>(output_state_sf),
            GetTensorData<float>(prod_scaling_factors),
            GetTensorData<float>(recovered_cell_weights),
            GetTensorData<int8_t>(input_quantized),
            GetTensorData<int8_t>(aux_input_quantized),
            GetTensorData<int8_t>(output_state_quantized),
            GetTensorData<int8_t>(cell_state_quantized), output_state_ptr,
            cell_state_ptr, GetTensorData<int32_t>(output_scratch_buffer),
            output_ptr, input_zp_ptr, aux_input_zp_ptr, output_state_zp_ptr,
            row_sums_ptr, row_sums_size, compute_row_sums,
            params->asymmetric_quantize_inputs, context);
      }
    }
  }

  return kTfLiteOk;
}

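// Fully-integer LSTM evaluation for the 8x8_16 kernel variant: int8
// activations and weights with an int16 cell state. The
// `effective_*_scale_a`/`_scale_b` pairs carried in IntegerLstmParameter are,
// as assumed here, the quantized multiplier and shift applied after each
// integer matmul. The sequence is walked one LstmStepInteger8x8_16 call per
// time step (time-major) or per batch entry and time step (batch-major).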
TfLiteStatus EvalInteger8x8_16(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, bool forward_sequence, bool time_major,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output,
    TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
    TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
    CpuBackendContext* context) {
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
  const int n_input = input->dims->data[input->dims->size - 1];
  int max_time, n_batch;
  if (input->dims->size == 2) {
    max_time = 1;
    n_batch = input->dims->data[0];
  } else {
    max_time = (time_major) ? input->dims->data[0] : input->dims->data[1];
    n_batch = (time_major) ? input->dims->data[1] : input->dims->data[0];
  }

  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  // Activation zero point.
  int output_state_zp = output_state->params.zero_point;

  // Get params for time/batch/sequence.
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];

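  // Time-major inputs are processed one time step at a time with all batches
  // covered by a single step call; batch-major inputs are processed per batch
  // entry with n_batch == 1 so the state pointers can be offset to each batch.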
  if (time_major) {
    const int input_step = n_batch * n_input;
    const int output_step = n_batch * output_batch_leading_dim;
    for (int t = 0; t < max_time; t++) {
      const int t_rel = t;
      int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
      const int8_t* input_ptr =
          GetTensorData<int8_t>(input) + t_rel * input_step;
      LstmStepInteger8x8_16(
          input_ptr, GetTensorData<int8_t>(input_to_input_weights),
          integer_lstm_param->effective_input_to_input_scale_a,
          integer_lstm_param->effective_input_to_input_scale_b,
          GetTensorData<int8_t>(input_to_forget_weights),
          integer_lstm_param->effective_input_to_forget_scale_a,
          integer_lstm_param->effective_input_to_forget_scale_b,
          GetTensorData<int8_t>(input_to_cell_weights),
          integer_lstm_param->effective_input_to_cell_scale_a,
          integer_lstm_param->effective_input_to_cell_scale_b,
          GetTensorData<int8_t>(input_to_output_weights),
          integer_lstm_param->effective_input_to_output_scale_a,
          integer_lstm_param->effective_input_to_output_scale_b,
          GetTensorData<int8_t>(recurrent_to_input_weights),
          integer_lstm_param->effective_recurrent_to_input_scale_a,
          integer_lstm_param->effective_recurrent_to_input_scale_b,
          GetTensorData<int8_t>(recurrent_to_forget_weights),
          integer_lstm_param->effective_recurrent_to_forget_scale_a,
          integer_lstm_param->effective_recurrent_to_forget_scale_b,
          GetTensorData<int8_t>(recurrent_to_cell_weights),
          integer_lstm_param->effective_recurrent_to_cell_scale_a,
          integer_lstm_param->effective_recurrent_to_cell_scale_b,
          GetTensorData<int8_t>(recurrent_to_output_weights),
          integer_lstm_param->effective_recurrent_to_output_scale_a,
          integer_lstm_param->effective_recurrent_to_output_scale_b,
          GetTensorData<int16_t>(cell_to_input_weights),
          integer_lstm_param->effective_cell_to_input_scale_a,
          integer_lstm_param->effective_cell_to_input_scale_b,
          GetTensorData<int16_t>(cell_to_forget_weights),
          integer_lstm_param->effective_cell_to_forget_scale_a,
          integer_lstm_param->effective_cell_to_forget_scale_b,
          GetTensorData<int16_t>(cell_to_output_weights),
          integer_lstm_param->effective_cell_to_output_scale_a,
          integer_lstm_param->effective_cell_to_output_scale_b,
          GetTensorData<int8_t>(projection_weights),
          integer_lstm_param->effective_proj_scale_a,
          integer_lstm_param->effective_proj_scale_b,
          integer_lstm_param->hidden_zp,
          integer_lstm_param->effective_hidden_scale_a,
          integer_lstm_param->effective_hidden_scale_b,
          GetTensorData<int16_t>(input_layer_norm_coefficients),
          integer_lstm_param->layer_norm_input_scale_a,
          integer_lstm_param->layer_norm_input_scale_b,
          GetTensorData<int16_t>(forget_layer_norm_coefficients),
          integer_lstm_param->layer_norm_forget_scale_a,
          integer_lstm_param->layer_norm_forget_scale_b,
          GetTensorData<int16_t>(cell_layer_norm_coefficients),
          integer_lstm_param->layer_norm_cell_scale_a,
          integer_lstm_param->layer_norm_cell_scale_b,
          GetTensorData<int16_t>(output_layer_norm_coefficients),
          integer_lstm_param->layer_norm_output_scale_a,
          integer_lstm_param->layer_norm_output_scale_b,
          GetTensorData<int32_t>(input_gate_bias),
          GetTensorData<int32_t>(forget_gate_bias),
          GetTensorData<int32_t>(cell_gate_bias),
          GetTensorData<int32_t>(output_gate_bias),
          integer_lstm_param->quantized_cell_clip,
          integer_lstm_param->quantized_proj_clip,
          integer_lstm_param->cell_scale,
          integer_lstm_param->input_variance_guard,
          integer_lstm_param->forget_variance_guard,
          integer_lstm_param->cell_variance_guard,
          integer_lstm_param->output_variance_guard,
          integer_lstm_param->input_to_forget_effective_bias.get(),
          integer_lstm_param->recurrent_to_forget_effective_bias.get(),
          integer_lstm_param->input_to_cell_effective_bias.get(),
          integer_lstm_param->recurrent_to_cell_effective_bias.get(),
          integer_lstm_param->input_to_output_effective_bias.get(),
          integer_lstm_param->recurrent_to_output_effective_bias.get(),
          integer_lstm_param->input_to_input_effective_bias.get(),
          integer_lstm_param->recurrent_to_input_effective_bias.get(),
          integer_lstm_param->projection_effective_bias.get(), n_batch, n_cell,
          n_input, n_output, GetTensorData<int8_t>(output_state),
          output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
          GetTensorData<int16_t>(scratch0), GetTensorData<int16_t>(scratch1),
          GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
          GetTensorData<int8_t>(scratch4), GetTensorData<int32_t>(scratch5),
          context);
    }
  } else {
    for (int b = 0; b < n_batch; b++) {
      const int input_step = n_input;
      const int output_step = output_batch_leading_dim;
      for (int t = 0; t < max_time; t++) {
        // If forward_sequence is true, step forward through the sequence;
        // otherwise step backwards.
        const int t_rel = forward_sequence ? t : max_time - t - 1;
        const int time_offset = b * max_time + t_rel;
        const int8_t* input_ptr =
            GetTensorData<int8_t>(input) + time_offset * input_step;
        int8_t* output_ptr =
            GetTensorData<int8_t>(output) + time_offset * output_step;

        // Offset the {output,cell}_state pointers to the right batch.
        int8_t* output_state_ptr =
            GetTensorData<int8_t>(output_state) + b * output_batch_leading_dim;
        int16_t* cell_state_ptr =
            GetTensorData<int16_t>(cell_state) + b * n_cell;

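        // Each batch entry is evaluated independently here, so the step below
        // is invoked with n_batch == 1 and the state pointers offset above.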
        LstmStepInteger8x8_16(
            input_ptr, GetTensorData<int8_t>(input_to_input_weights),
            integer_lstm_param->effective_input_to_input_scale_a,
            integer_lstm_param->effective_input_to_input_scale_b,
            GetTensorData<int8_t>(input_to_forget_weights),
            integer_lstm_param->effective_input_to_forget_scale_a,
            integer_lstm_param->effective_input_to_forget_scale_b,
            GetTensorData<int8_t>(input_to_cell_weights),
            integer_lstm_param->effective_input_to_cell_scale_a,
            integer_lstm_param->effective_input_to_cell_scale_b,
            GetTensorData<int8_t>(input_to_output_weights),
            integer_lstm_param->effective_input_to_output_scale_a,
            integer_lstm_param->effective_input_to_output_scale_b,
            GetTensorData<int8_t>(recurrent_to_input_weights),
            integer_lstm_param->effective_recurrent_to_input_scale_a,
            integer_lstm_param->effective_recurrent_to_input_scale_b,
            GetTensorData<int8_t>(recurrent_to_forget_weights),
            integer_lstm_param->effective_recurrent_to_forget_scale_a,
            integer_lstm_param->effective_recurrent_to_forget_scale_b,
            GetTensorData<int8_t>(recurrent_to_cell_weights),
            integer_lstm_param->effective_recurrent_to_cell_scale_a,
            integer_lstm_param->effective_recurrent_to_cell_scale_b,
            GetTensorData<int8_t>(recurrent_to_output_weights),
            integer_lstm_param->effective_recurrent_to_output_scale_a,
            integer_lstm_param->effective_recurrent_to_output_scale_b,
            GetTensorData<int16_t>(cell_to_input_weights),
            integer_lstm_param->effective_cell_to_input_scale_a,
            integer_lstm_param->effective_cell_to_input_scale_b,
            GetTensorData<int16_t>(cell_to_forget_weights),
            integer_lstm_param->effective_cell_to_forget_scale_a,
            integer_lstm_param->effective_cell_to_forget_scale_b,
            GetTensorData<int16_t>(cell_to_output_weights),
            integer_lstm_param->effective_cell_to_output_scale_a,
            integer_lstm_param->effective_cell_to_output_scale_b,
            GetTensorData<int8_t>(projection_weights),
            integer_lstm_param->effective_proj_scale_a,
            integer_lstm_param->effective_proj_scale_b,
            integer_lstm_param->hidden_zp,
            integer_lstm_param->effective_hidden_scale_a,
            integer_lstm_param->effective_hidden_scale_b,
            GetTensorData<int16_t>(input_layer_norm_coefficients),
            integer_lstm_param->layer_norm_input_scale_a,
            integer_lstm_param->layer_norm_input_scale_b,
            GetTensorData<int16_t>(forget_layer_norm_coefficients),
            integer_lstm_param->layer_norm_forget_scale_a,
            integer_lstm_param->layer_norm_forget_scale_b,
            GetTensorData<int16_t>(cell_layer_norm_coefficients),
            integer_lstm_param->layer_norm_cell_scale_a,
            integer_lstm_param->layer_norm_cell_scale_b,
            GetTensorData<int16_t>(output_layer_norm_coefficients),
            integer_lstm_param->layer_norm_output_scale_a,
            integer_lstm_param->layer_norm_output_scale_b,
            GetTensorData<int32_t>(input_gate_bias),
            GetTensorData<int32_t>(forget_gate_bias),
            GetTensorData<int32_t>(cell_gate_bias),
            GetTensorData<int32_t>(output_gate_bias),
            integer_lstm_param->quantized_cell_clip,
            integer_lstm_param->quantized_proj_clip,
            integer_lstm_param->cell_scale,
            integer_lstm_param->input_variance_guard,
            integer_lstm_param->forget_variance_guard,
            integer_lstm_param->cell_variance_guard,
            integer_lstm_param->output_variance_guard,
            integer_lstm_param->input_to_forget_effective_bias.get(),
            integer_lstm_param->recurrent_to_forget_effective_bias.get(),
            integer_lstm_param->input_to_cell_effective_bias.get(),
            integer_lstm_param->recurrent_to_cell_effective_bias.get(),
            integer_lstm_param->input_to_output_effective_bias.get(),
            integer_lstm_param->recurrent_to_output_effective_bias.get(),
            integer_lstm_param->input_to_input_effective_bias.get(),
            integer_lstm_param->recurrent_to_input_effective_bias.get(),
            integer_lstm_param->projection_effective_bias.get(), /*n_batch=*/1,
            n_cell, n_input, n_output, output_state_ptr, output_state_zp,
            cell_state_ptr, output_ptr, GetTensorData<int16_t>(scratch0),
            GetTensorData<int16_t>(scratch1), GetTensorData<int16_t>(scratch2),
            GetTensorData<int16_t>(scratch3), GetTensorData<int8_t>(scratch4),
            GetTensorData<int32_t>(scratch5), context);
      }
    }
  }

  return kTfLiteOk;
}

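// Fully-integer LSTM evaluation for the 8x8_8 kernel variant. Inputs, output
// state and outputs are int8 with explicit zero points, the cell state is
// carried as int16, and the per-gate intermediate scales and zero points come
// from IntegerLstmParameter. Unlike EvalInteger8x8_16, this path takes no
// time_major/forward_sequence parameters.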
TfLiteStatus EvalInteger8x8_8(
    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
    const TfLiteTensor* input_to_forget_weights,
    const TfLiteTensor* input_to_cell_weights,
    const TfLiteTensor* input_to_output_weights,
    const TfLiteTensor* recurrent_to_input_weights,
    const TfLiteTensor* recurrent_to_forget_weights,
    const TfLiteTensor* recurrent_to_cell_weights,
    const TfLiteTensor* recurrent_to_output_weights,
    const TfLiteTensor* cell_to_input_weights,
    const TfLiteTensor* cell_to_forget_weights,
    const TfLiteTensor* cell_to_output_weights,
    const TfLiteTensor* input_layer_norm_coefficients,
    const TfLiteTensor* forget_layer_norm_coefficients,
    const TfLiteTensor* cell_layer_norm_coefficients,
    const TfLiteTensor* output_layer_norm_coefficients,
    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
    const TfLiteTensor* cell_gate_bias, const TfLiteTensor* output_gate_bias,
    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
    const TfLiteLSTMParams* params, TfLiteTensor* output_state,
    TfLiteTensor* cell_state, TfLiteTensor* output,
    const lstm_eval::IntegerLstmParameter* integer_lstm_param,
    TfLiteTensor* scratch0, TfLiteTensor* scratch1, TfLiteTensor* scratch2,
    TfLiteTensor* scratch3, TfLiteTensor* scratch4, TfLiteTensor* scratch5,
    TfLiteTensor* scratch6, TfLiteTensor* scratch7) {
  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
  const int n_input = input->dims->data[input->dims->size - 1];
  int max_time, n_batch;
  if (input->dims->size == 2) {
    max_time = 1;
    n_batch = input->dims->data[0];
  } else {
    max_time = input->dims->data[0];
    n_batch = input->dims->data[1];
  }

  // n_cell and n_output will be the same size when there is no projection.
  const int n_cell = input_to_output_weights->dims->data[0];
  const int n_output = recurrent_to_output_weights->dims->data[1];

  const int32_t input_zp = input->params.zero_point;
  const int32_t output_state_zp = output_state->params.zero_point;

  // Get params for time/batch/sequence.
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];
  const int input_step = n_batch * n_input;
  const int output_step = n_batch * output_batch_leading_dim;

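  // This variant walks the sequence forward in time-major order only
  // (t_rel == t); each step covers all batch entries at once.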
  for (int t = 0; t < max_time; t++) {
    const int t_rel = t;
    int8_t* output_ptr = GetTensorData<int8_t>(output) + t_rel * output_step;
    // Input can be int8 asymmetric or int16 symmetric.
    const int8_t* input_ptr = GetTensorData<int8_t>(input) + t_rel * input_step;
    lstm_eval::LstmStepInteger8x8_8(
        input_ptr, input_zp,

        GetTensorData<int8_t>(input_to_input_weights),
        integer_lstm_param->effective_input_to_input_scale_a,
        integer_lstm_param->effective_input_to_input_scale_b,

        GetTensorData<int8_t>(input_to_forget_weights),
        integer_lstm_param->effective_input_to_forget_scale_a,
        integer_lstm_param->effective_input_to_forget_scale_b,

        GetTensorData<int8_t>(input_to_cell_weights),
        integer_lstm_param->effective_input_to_cell_scale_a,
        integer_lstm_param->effective_input_to_cell_scale_b,

        GetTensorData<int8_t>(input_to_output_weights),
        integer_lstm_param->effective_input_to_output_scale_a,
        integer_lstm_param->effective_input_to_output_scale_b,

        GetTensorData<int8_t>(recurrent_to_input_weights),
        integer_lstm_param->effective_recurrent_to_input_scale_a,
        integer_lstm_param->effective_recurrent_to_input_scale_b,

        GetTensorData<int8_t>(recurrent_to_forget_weights),
        integer_lstm_param->effective_recurrent_to_forget_scale_a,
        integer_lstm_param->effective_recurrent_to_forget_scale_b,

        GetTensorData<int8_t>(recurrent_to_cell_weights),
        integer_lstm_param->effective_recurrent_to_cell_scale_a,
        integer_lstm_param->effective_recurrent_to_cell_scale_b,

        GetTensorData<int8_t>(recurrent_to_output_weights),
        integer_lstm_param->effective_recurrent_to_output_scale_a,
        integer_lstm_param->effective_recurrent_to_output_scale_b,

        GetTensorData<int8_t>(cell_to_input_weights),
        integer_lstm_param->effective_cell_to_input_scale_a,
        integer_lstm_param->effective_cell_to_input_scale_b,

        GetTensorData<int8_t>(cell_to_forget_weights),
        integer_lstm_param->effective_cell_to_forget_scale_a,
        integer_lstm_param->effective_cell_to_forget_scale_b,

        GetTensorData<int8_t>(cell_to_output_weights),
        integer_lstm_param->effective_cell_to_output_scale_a,
        integer_lstm_param->effective_cell_to_output_scale_b,

        GetTensorData<int8_t>(projection_weights),
        integer_lstm_param->effective_proj_scale_a,
        integer_lstm_param->effective_proj_scale_b,

        GetTensorData<int16_t>(input_layer_norm_coefficients),
        integer_lstm_param->layer_norm_input_scale_a,
        integer_lstm_param->layer_norm_input_scale_b,

        GetTensorData<int16_t>(forget_layer_norm_coefficients),
        integer_lstm_param->layer_norm_forget_scale_a,
        integer_lstm_param->layer_norm_forget_scale_b,

        GetTensorData<int16_t>(cell_layer_norm_coefficients),
        integer_lstm_param->layer_norm_cell_scale_a,
        integer_lstm_param->layer_norm_cell_scale_b,

        GetTensorData<int16_t>(output_layer_norm_coefficients),
        integer_lstm_param->layer_norm_output_scale_a,
        integer_lstm_param->layer_norm_output_scale_b,

        GetTensorData<int32_t>(input_gate_bias),
        GetTensorData<int32_t>(forget_gate_bias),
        GetTensorData<int32_t>(cell_gate_bias),
        GetTensorData<int32_t>(output_gate_bias),
        GetTensorData<int32_t>(projection_bias),

        params, integer_lstm_param->intermediate_scale_a,
        integer_lstm_param->intermediate_scale_b,
        integer_lstm_param->intermediate_zp,
        integer_lstm_param->quantized_cell_clip,
        integer_lstm_param->quantized_proj_clip, n_batch, n_cell, n_input,
        n_output, output_batch_leading_dim, GetTensorData<int8_t>(output_state),
        output_state_zp, GetTensorData<int16_t>(cell_state), output_ptr,
        GetTensorData<int8_t>(scratch0), GetTensorData<int8_t>(scratch1),
        GetTensorData<int16_t>(scratch2), GetTensorData<int16_t>(scratch3),
        GetTensorData<int16_t>(scratch4), GetTensorData<int16_t>(scratch5),
        GetTensorData<int16_t>(scratch6), GetTensorData<int16_t>(scratch7));
  }

  return kTfLiteOk;
}

}  // namespace lstm_eval
}  // namespace builtin
}  // namespace ops
}  // namespace tflite
