1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <math.h> |
17 | |
18 | #include <algorithm> |
19 | #include <cstddef> |
20 | |
21 | #include "tensorflow/lite/c/builtin_op_data.h" |
22 | #include "tensorflow/lite/c/common.h" |
23 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
24 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
25 | #include "tensorflow/lite/kernels/internal/kernel_utils.h" |
26 | #include "tensorflow/lite/kernels/internal/quantization_util.h" |
27 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
28 | #include "tensorflow/lite/kernels/internal/tensor_utils.h" |
29 | #include "tensorflow/lite/kernels/kernel_util.h" |
30 | #include "tensorflow/lite/kernels/lstm_eval.h" |
31 | #include "tensorflow/lite/kernels/lstm_shared.h" |
32 | |
33 | namespace tflite { |
34 | namespace ops { |
35 | namespace builtin { |
36 | namespace unidirectional_sequence_lstm { |
37 | namespace { |
38 | |
39 | struct OpData { |
  // Whether the LSTM uses layer normalization.
41 | bool use_layer_norm; |
  // Index of the first scratch tensor allocated for this op.
43 | int scratch_tensor_index; |
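  // Whether the row sums for the hybrid path still need to be computed; set
  // in Prepare() and consumed by the hybrid evaluation.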
44 | bool compute_row_sums = false; |
45 | |
46 | lstm_eval::IntegerLstmParameter integer_lstm_param; |
47 | }; |
48 | |
49 | TfLiteStatus PopulateQuantizedLstmParams8x8_16( |
50 | TfLiteContext* context, TfLiteNode* node, |
51 | lstm_eval::IntegerLstmParameter* integer_lstm_param) { |
52 | // Calculate quantized clip for projection and cell. |
53 | const auto* params = |
54 | static_cast<TfLiteUnidirectionalSequenceLSTMParams*>(node->builtin_data); |
55 | const float cell_clip = params->cell_clip; |
56 | const float proj_clip = params->proj_clip; |
57 | |
58 | const TfLiteTensor* cell_state = |
59 | GetVariableInput(context, node, lstm::full::kCellStateTensor); |
60 | TF_LITE_ENSURE(context, cell_state != nullptr); |
61 | TfLiteTensor* output_tensor; |
62 | TF_LITE_ENSURE_OK( |
63 | context, |
64 | GetOutputSafe(context, node, lstm::full::kOutputTensor, &output_tensor)); |
65 | |
66 | TF_LITE_ENSURE(context, |
67 | cell_state->quantization.type != kTfLiteNoQuantization); |
68 | auto* cell_state_params = |
69 | static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params); |
70 | TF_LITE_ENSURE(context, |
71 | output_tensor->quantization.type != kTfLiteNoQuantization); |
72 | auto* proj_params = static_cast<TfLiteAffineQuantization*>( |
73 | output_tensor->quantization.params); |
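  // The float clip thresholds are mapped into the quantized domain by
  // dividing by the corresponding state/output scale and clamping to the
  // target integer range: quantized_clip = clamp(clip / scale, min, max).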
74 | if (cell_clip > 0.0) { |
75 | integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min( |
76 | std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f), |
77 | 32767.0f)); |
78 | } else { |
79 | integer_lstm_param->quantized_cell_clip = 0; |
80 | } |
81 | if (proj_clip > 0.0) { |
82 | integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min( |
83 | std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f)); |
84 | } else { |
85 | integer_lstm_param->quantized_proj_clip = 0; |
86 | } |
87 | |
88 | // Calculate effective scales. |
89 | OpData* op_data = static_cast<OpData*>(node->user_data); |
90 | const bool use_layer_norm = op_data->use_layer_norm; |
91 | |
92 | const TfLiteTensor* input; |
93 | TF_LITE_ENSURE_OK( |
94 | context, GetInputSafe(context, node, lstm::full::kInputTensor, &input)); |
95 | |
96 | const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor( |
97 | context, node, lstm::full::kInputToInputWeightsTensor); |
98 | const TfLiteTensor* input_to_forget_weights; |
99 | TF_LITE_ENSURE_OK( |
100 | context, |
101 | GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor, |
102 | &input_to_forget_weights)); |
103 | const TfLiteTensor* input_to_cell_weights; |
104 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, |
105 | lstm::full::kInputToCellWeightsTensor, |
106 | &input_to_cell_weights)); |
107 | const TfLiteTensor* input_to_output_weights; |
108 | TF_LITE_ENSURE_OK( |
109 | context, |
110 | GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor, |
111 | &input_to_output_weights)); |
112 | |
113 | const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor( |
114 | context, node, lstm::full::kRecurrentToInputWeightsTensor); |
115 | const TfLiteTensor* recurrent_to_forget_weights; |
116 | TF_LITE_ENSURE_OK( |
117 | context, |
118 | GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor, |
119 | &recurrent_to_forget_weights)); |
120 | const TfLiteTensor* recurrent_to_cell_weights; |
121 | TF_LITE_ENSURE_OK( |
122 | context, |
123 | GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor, |
124 | &recurrent_to_cell_weights)); |
125 | const TfLiteTensor* recurrent_to_output_weights; |
126 | TF_LITE_ENSURE_OK( |
127 | context, |
128 | GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor, |
129 | &recurrent_to_output_weights)); |
130 | |
131 | const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor( |
132 | context, node, lstm::full::kCellToInputWeightsTensor); |
133 | const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor( |
134 | context, node, lstm::full::kCellToForgetWeightsTensor); |
135 | const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor( |
136 | context, node, lstm::full::kCellToOutputWeightsTensor); |
137 | |
138 | const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( |
139 | context, node, lstm::full::kInputLayerNormCoefficientsTensor); |
140 | const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor( |
141 | context, node, lstm::full::kForgetLayerNormCoefficientsTensor); |
142 | const TfLiteTensor* cell_layer_norm_coefficients = GetOptionalInputTensor( |
143 | context, node, lstm::full::kCellLayerNormCoefficientsTensor); |
144 | const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor( |
145 | context, node, lstm::full::kOutputLayerNormCoefficientsTensor); |
146 | |
147 | const TfLiteTensor* projection_weights = GetOptionalInputTensor( |
148 | context, node, lstm::full::kProjectionWeightsTensor); |
149 | |
150 | TfLiteTensor* output_state = |
151 | GetVariableInput(context, node, lstm::full::kOutputStateTensor); |
152 | TF_LITE_ENSURE(context, output_state != nullptr); |
153 | |
  // Since we have already checked that the weights are either all present or
  // all absent, checking the existence of only one is enough to determine
  // each condition.
156 | const bool use_cifg = (input_to_input_weights == nullptr); |
157 | const bool use_peephole = (cell_to_output_weights != nullptr); |
158 | const bool use_projection = (projection_weights != nullptr); |
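  // CIFG ("coupled input and forget gate") derives the input gate as
  // 1 - forget_gate, so the input-gate weight tensors are absent entirely.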
159 | |
160 | // Get intermediate scales and zero points. |
161 | std::vector<float> intermediate_scale; |
  std::vector<int32_t> intermediate_zp;
163 | for (int i = 0; i < 4; ++i) { |
164 | if (use_layer_norm) { |
165 | TfLiteTensor* intermediate; |
166 | TF_LITE_ENSURE_OK(context, |
167 | GetIntermediatesSafe(context, node, i, &intermediate)); |
168 | TF_LITE_ENSURE(context, |
169 | intermediate->quantization.type != kTfLiteNoQuantization); |
170 | auto* params = static_cast<TfLiteAffineQuantization*>( |
171 | intermediate->quantization.params); |
172 | intermediate_scale.push_back(params->scale->data[0]); |
173 | intermediate_zp.push_back(params->zero_point->data[0]); |
174 | } else { |
175 | // Q3.12 for activation functions. |
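      // Q3.12 fixed point uses 3 integer and 12 fractional bits, so the
      // scale is 2^-12 and the representable range is roughly [-8, 8).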
176 | intermediate_scale.push_back(std::pow(2, -12)); |
177 | intermediate_zp.push_back(0); |
178 | } |
179 | } |
  // In the absence of projection, hidden becomes the output and this
  // intermediate is ignored.
182 | TfLiteTensor* hidden; |
183 | TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden)); |
184 | TF_LITE_ENSURE(context, hidden->quantization.type != kTfLiteNoQuantization); |
185 | auto* hidden_params = |
186 | static_cast<TfLiteAffineQuantization*>(hidden->quantization.params); |
187 | intermediate_scale.push_back(hidden_params->scale->data[0]); |
188 | intermediate_zp.push_back(hidden_params->zero_point->data[0]); |
189 | |
190 | // Scales. |
191 | const float default_scale = 1.0; |
192 | float input_scale = default_scale; |
193 | float input_to_input_weight_scale = default_scale; |
194 | float recurrent_to_input_weight_scale = default_scale; |
195 | float cell_to_input_weight_scale = default_scale; |
196 | float input_to_forget_weight_scale = default_scale; |
197 | float recurrent_to_forget_weight_scale = default_scale; |
198 | float cell_to_forget_weight_scale = default_scale; |
199 | float input_to_cell_weight_scale = default_scale; |
200 | float recurrent_to_cell_weight_scale = default_scale; |
201 | float input_to_output_weight_scale = default_scale; |
202 | float recurrent_to_output_weight_scale = default_scale; |
203 | float cell_to_output_weight_scale = default_scale; |
204 | float projection_weight_scale = default_scale; |
205 | float layer_norm_input_scale = default_scale; |
206 | float layer_norm_forget_scale = default_scale; |
207 | float layer_norm_cell_scale = default_scale; |
208 | float layer_norm_output_scale = default_scale; |
209 | float output_state_scale = default_scale; |
210 | int cell_scale = 1; |
211 | |
212 | // Effective scales. |
213 | float effective_input_to_input_scale = default_scale; |
214 | float effective_recurrent_to_input_scale = default_scale; |
215 | float effective_cell_to_input_scale = default_scale; |
216 | float effective_input_to_forget_scale = default_scale; |
217 | float effective_recurrent_to_forget_scale = default_scale; |
218 | float effective_cell_to_forget_scale = default_scale; |
219 | float effective_input_to_cell_scale = default_scale; |
220 | float effective_recurrent_to_cell_scale = default_scale; |
221 | float effective_input_to_output_scale = default_scale; |
222 | float effective_recurrent_to_output_scale = default_scale; |
223 | float effective_cell_to_output_scale = default_scale; |
224 | float effective_proj_scale = default_scale; |
225 | float effective_hidden_scale = default_scale; |
226 | |
227 | // Populate scales. |
228 | if (!use_cifg) { |
229 | input_to_input_weight_scale = input_to_input_weights->params.scale; |
230 | recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale; |
231 | } |
232 | |
233 | if (use_peephole) { |
234 | if (!use_cifg) { |
235 | cell_to_input_weight_scale = cell_to_input_weights->params.scale; |
236 | } |
237 | cell_to_forget_weight_scale = cell_to_forget_weights->params.scale; |
238 | cell_to_output_weight_scale = cell_to_output_weights->params.scale; |
239 | } |
240 | |
241 | if (use_layer_norm) { |
242 | if (!use_cifg) { |
243 | layer_norm_input_scale = input_layer_norm_coefficients->params.scale; |
244 | } |
245 | layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale; |
246 | layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale; |
247 | layer_norm_output_scale = output_layer_norm_coefficients->params.scale; |
248 | } |
249 | |
250 | if (use_projection) { |
251 | projection_weight_scale = projection_weights->params.scale; |
252 | } |
253 | output_state_scale = output_state->params.scale; |
254 | |
255 | input_to_forget_weight_scale = input_to_forget_weights->params.scale; |
256 | input_to_cell_weight_scale = input_to_cell_weights->params.scale; |
257 | input_to_output_weight_scale = input_to_output_weights->params.scale; |
258 | recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale; |
259 | recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale; |
260 | recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale; |
261 | |
  // Check the cell state tensor (already fetched above): its scale must be a
  // power of two so that the kernel can apply it as a bit shift.
263 | TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale)); |
264 | // TF_LITE_ENSURE(context, cell_scale <= -9); |
265 | integer_lstm_param->cell_scale = cell_scale; |
266 | input_scale = input->params.scale; |
267 | |
268 | // Calculate effective scales. |
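  // For a quantized matmul, the int32 accumulator holds sum(w_q * x_q) whose
  // real value carries the scale weight_scale * input_scale. Rescaling that
  // accumulator into an intermediate tensor of scale s therefore requires a
  // multiply by (weight_scale * input_scale) / s, the "effective" scale.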
269 | if (!use_cifg) { |
270 | effective_input_to_input_scale = |
271 | input_to_input_weight_scale * input_scale / intermediate_scale[0]; |
272 | effective_recurrent_to_input_scale = recurrent_to_input_weight_scale * |
273 | output_state_scale / |
274 | intermediate_scale[0]; |
275 | } |
276 | effective_input_to_forget_scale = |
277 | input_to_forget_weight_scale * input_scale / intermediate_scale[1]; |
278 | effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale * |
279 | output_state_scale / |
280 | intermediate_scale[1]; |
281 | |
282 | effective_input_to_cell_scale = |
283 | input_to_cell_weight_scale * input_scale / intermediate_scale[2]; |
284 | effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale * |
285 | output_state_scale / |
286 | intermediate_scale[2]; |
287 | |
288 | effective_input_to_output_scale = |
289 | input_to_output_weight_scale * input_scale / intermediate_scale[3]; |
290 | effective_recurrent_to_output_scale = recurrent_to_output_weight_scale * |
291 | output_state_scale / |
292 | intermediate_scale[3]; |
293 | |
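  // The gate outputs of the 8x8->16 kernel are Q0.15 (scale 2^-15). The
  // hidden value is the product of two such Q0.15 values, hence the two
  // 2^-15 factors below; dividing by intermediate_scale[4] rescales that
  // product into the hidden intermediate tensor.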
294 | effective_hidden_scale = |
295 | std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15); |
296 | |
297 | effective_proj_scale = |
298 | projection_weight_scale * intermediate_scale[4] / output_state_scale; |
299 | |
300 | if (use_peephole) { |
301 | if (!use_cifg) { |
302 | effective_cell_to_input_scale = std::pow(2, cell_scale) * // NOLINT |
303 | cell_to_input_weight_scale / |
304 | intermediate_scale[0]; |
305 | } |
306 | effective_cell_to_forget_scale = std::pow(2, cell_scale) * // NOLINT |
307 | cell_to_forget_weight_scale / |
308 | intermediate_scale[1]; |
309 | effective_cell_to_output_scale = std::pow(2, cell_scale) * // NOLINT |
310 | cell_to_output_weight_scale / |
311 | intermediate_scale[3]; |
312 | } |
313 | |
314 | // Decompose scales. |
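  // QuantizeMultiplier() represents each real-valued scale S as an int32
  // fixed-point multiplier a and a bit shift b with S ~= (a / 2^31) * 2^b,
  // so the kernel can apply S with integer multiply-and-shift only.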
315 | QuantizeMultiplier(effective_input_to_input_scale, |
316 | &integer_lstm_param->effective_input_to_input_scale_a, |
317 | &integer_lstm_param->effective_input_to_input_scale_b); |
318 | QuantizeMultiplier(effective_recurrent_to_input_scale, |
319 | &integer_lstm_param->effective_recurrent_to_input_scale_a, |
320 | &integer_lstm_param->effective_recurrent_to_input_scale_b); |
321 | QuantizeMultiplier(effective_cell_to_input_scale, |
322 | &integer_lstm_param->effective_cell_to_input_scale_a, |
323 | &integer_lstm_param->effective_cell_to_input_scale_b); |
324 | QuantizeMultiplier(effective_input_to_forget_scale, |
325 | &integer_lstm_param->effective_input_to_forget_scale_a, |
326 | &integer_lstm_param->effective_input_to_forget_scale_b); |
327 | QuantizeMultiplier( |
328 | effective_recurrent_to_forget_scale, |
329 | &integer_lstm_param->effective_recurrent_to_forget_scale_a, |
330 | &integer_lstm_param->effective_recurrent_to_forget_scale_b); |
331 | QuantizeMultiplier(effective_cell_to_forget_scale, |
332 | &integer_lstm_param->effective_cell_to_forget_scale_a, |
333 | &integer_lstm_param->effective_cell_to_forget_scale_b); |
334 | QuantizeMultiplier(effective_input_to_cell_scale, |
335 | &integer_lstm_param->effective_input_to_cell_scale_a, |
336 | &integer_lstm_param->effective_input_to_cell_scale_b); |
337 | QuantizeMultiplier(effective_recurrent_to_cell_scale, |
338 | &integer_lstm_param->effective_recurrent_to_cell_scale_a, |
339 | &integer_lstm_param->effective_recurrent_to_cell_scale_b); |
340 | QuantizeMultiplier(effective_input_to_output_scale, |
341 | &integer_lstm_param->effective_input_to_output_scale_a, |
342 | &integer_lstm_param->effective_input_to_output_scale_b); |
343 | QuantizeMultiplier( |
344 | effective_recurrent_to_output_scale, |
345 | &integer_lstm_param->effective_recurrent_to_output_scale_a, |
346 | &integer_lstm_param->effective_recurrent_to_output_scale_b); |
347 | QuantizeMultiplier(effective_cell_to_output_scale, |
348 | &integer_lstm_param->effective_cell_to_output_scale_a, |
349 | &integer_lstm_param->effective_cell_to_output_scale_b); |
350 | QuantizeMultiplier(effective_proj_scale, |
351 | &integer_lstm_param->effective_proj_scale_a, |
352 | &integer_lstm_param->effective_proj_scale_b); |
353 | QuantizeMultiplier(effective_hidden_scale, |
354 | &integer_lstm_param->effective_hidden_scale_a, |
355 | &integer_lstm_param->effective_hidden_scale_b); |
356 | QuantizeMultiplier(layer_norm_input_scale, |
357 | &integer_lstm_param->layer_norm_input_scale_a, |
358 | &integer_lstm_param->layer_norm_input_scale_b); |
359 | QuantizeMultiplier(layer_norm_forget_scale, |
360 | &integer_lstm_param->layer_norm_forget_scale_a, |
361 | &integer_lstm_param->layer_norm_forget_scale_b); |
362 | QuantizeMultiplier(layer_norm_cell_scale, |
363 | &integer_lstm_param->layer_norm_cell_scale_a, |
364 | &integer_lstm_param->layer_norm_cell_scale_b); |
365 | QuantizeMultiplier(layer_norm_output_scale, |
366 | &integer_lstm_param->layer_norm_output_scale_a, |
367 | &integer_lstm_param->layer_norm_output_scale_b); |
368 | |
369 | integer_lstm_param->hidden_zp = intermediate_zp[4]; |
370 | |
371 | // 10000 is used to make sure the kernel logic does not overflow. |
372 | if (!use_cifg) { |
373 | integer_lstm_param->input_variance_guard = |
374 | std::max(1, static_cast<int32_t>(10000 * layer_norm_input_scale)); |
375 | } |
376 | integer_lstm_param->forget_variance_guard = |
377 | std::max(1, static_cast<int32_t>(10000 * layer_norm_forget_scale)); |
378 | integer_lstm_param->cell_variance_guard = |
379 | std::max(1, static_cast<int32_t>(10000 * layer_norm_cell_scale)); |
380 | integer_lstm_param->output_variance_guard = |
381 | std::max(1, static_cast<int32_t>(10000 * layer_norm_output_scale)); |
382 | |
383 | return kTfLiteOk; |
384 | } |
385 | |
386 | } // namespace |
387 | |
388 | // Temporary tensors |
389 | enum TemporaryTensor { |
390 | kScratchBuffer = 0, |
391 | kInputQuantized = 1, |
392 | kOutputStateQuantized = 2, |
393 | kCellStateQuantized = 3, |
394 | kInputScalingFactors = 4, |
395 | kOutputStateScalingFactors = 5, |
396 | kProductScalingFactors = 6, |
397 | kRecoveredCellWeights = 7, |
398 | kAccumScratch = 8, |
399 | kInputZeroPoints = 9, |
400 | kOutputStateZeroPoints = 10, |
401 | kRowSums = 11, |
402 | kNumTemporaryTensors = 12, |
403 | }; |
404 | |
405 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
406 | auto* op_data = new OpData(); |
407 | context->AddTensors(context, kNumTemporaryTensors, |
408 | &op_data->scratch_tensor_index); |
409 | return op_data; |
410 | } |
411 | |
412 | void Free(TfLiteContext* context, void* buffer) { |
413 | delete reinterpret_cast<OpData*>(buffer); |
414 | } |
415 | |
// Check that the input tensor dimensions are consistent with each other.
417 | TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, |
418 | TfLiteNode* node, int n_input, |
419 | int n_output, int n_cell, |
420 | bool use_layer_norm, bool is_integer) { |
421 | const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); |
422 | |
423 | // Making sure clipping parameters have valid values. |
424 | // == 0 means no clipping |
425 | // > 0 means clipping |
426 | TF_LITE_ENSURE(context, params->cell_clip >= 0); |
427 | TF_LITE_ENSURE(context, params->proj_clip >= 0); |
428 | |
429 | const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor( |
430 | context, node, lstm::full::kInputToInputWeightsTensor); |
431 | if (input_to_input_weights != nullptr) { |
432 | TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); |
433 | TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); |
434 | TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); |
435 | } |
436 | |
437 | const TfLiteTensor* input_to_forget_weights; |
438 | TF_LITE_ENSURE_OK( |
439 | context, |
440 | GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor, |
441 | &input_to_forget_weights)); |
442 | TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); |
443 | TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); |
444 | TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); |
445 | |
446 | const TfLiteTensor* input_to_cell_weights; |
447 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, |
448 | lstm::full::kInputToCellWeightsTensor, |
449 | &input_to_cell_weights)); |
450 | TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); |
451 | TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); |
452 | TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); |
453 | |
454 | const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor( |
455 | context, node, lstm::full::kRecurrentToInputWeightsTensor); |
456 | if (recurrent_to_input_weights != nullptr) { |
457 | TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); |
458 | TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], |
459 | n_cell); |
460 | TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], |
461 | n_output); |
462 | } |
463 | |
464 | const TfLiteTensor* recurrent_to_forget_weights; |
465 | TF_LITE_ENSURE_OK( |
466 | context, |
467 | GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor, |
468 | &recurrent_to_forget_weights)); |
469 | TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); |
470 | TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], |
471 | n_cell); |
472 | TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], |
473 | n_output); |
474 | |
475 | const TfLiteTensor* recurrent_to_cell_weights; |
476 | TF_LITE_ENSURE_OK( |
477 | context, |
478 | GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor, |
479 | &recurrent_to_cell_weights)); |
480 | TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); |
481 | TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); |
482 | TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], |
483 | n_output); |
484 | |
485 | // We make sure the input-gate's parameters are either both present (regular |
486 | // LSTM) or not at all (CIFG-LSTM). |
487 | const bool cifg_weights_all_or_none = |
488 | ((input_to_input_weights != nullptr) && |
489 | (recurrent_to_input_weights != nullptr)) || |
490 | ((input_to_input_weights == nullptr) && |
491 | (recurrent_to_input_weights == nullptr)); |
492 | TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); |
493 | |
494 | const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor( |
495 | context, node, lstm::full::kCellToInputWeightsTensor); |
496 | if (cell_to_input_weights != nullptr) { |
497 | TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); |
498 | TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); |
499 | TF_LITE_ENSURE_TYPES_EQ( |
500 | context, cell_to_input_weights->type, |
501 | is_integer ? kTfLiteInt16 : input_to_forget_weights->type); |
502 | } |
503 | |
504 | const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor( |
505 | context, node, lstm::full::kCellToForgetWeightsTensor); |
506 | if (cell_to_forget_weights != nullptr) { |
507 | TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); |
508 | TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); |
509 | TF_LITE_ENSURE_TYPES_EQ( |
510 | context, cell_to_forget_weights->type, |
511 | is_integer ? kTfLiteInt16 : input_to_forget_weights->type); |
512 | } |
513 | |
514 | const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor( |
515 | context, node, lstm::full::kCellToOutputWeightsTensor); |
516 | if (cell_to_output_weights != nullptr) { |
517 | TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); |
518 | TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); |
519 | TF_LITE_ENSURE_TYPES_EQ( |
520 | context, cell_to_output_weights->type, |
521 | is_integer ? kTfLiteInt16 : input_to_forget_weights->type); |
522 | } |
523 | |
  // Make sure the peephole weights are either all present or all absent.
525 | const bool use_cifg = (input_to_input_weights == nullptr); |
526 | const bool peephole_weights_all_or_none = |
527 | ((cell_to_input_weights != nullptr || use_cifg) && |
528 | (cell_to_forget_weights != nullptr) && |
529 | (cell_to_output_weights != nullptr)) || |
530 | ((cell_to_input_weights == nullptr) && |
531 | (cell_to_forget_weights == nullptr) && |
532 | (cell_to_output_weights == nullptr)); |
533 | TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); |
534 | |
535 | // Make sure the input gate bias is present only when not a CIFG-LSTM. |
536 | const TfLiteTensor* input_gate_bias = |
537 | GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor); |
538 | if (use_cifg) { |
539 | TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); |
540 | } else { |
541 | TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); |
542 | TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); |
543 | if (is_integer) { |
544 | TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32); |
545 | } else { |
546 | TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32); |
547 | } |
548 | } |
549 | |
550 | const TfLiteTensor* forget_gate_bias; |
551 | TF_LITE_ENSURE_OK( |
552 | context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor, |
553 | &forget_gate_bias)); |
554 | TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); |
555 | TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); |
556 | if (is_integer) { |
557 | TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32); |
558 | } else { |
559 | TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32); |
560 | } |
561 | |
562 | const TfLiteTensor* cell_gate_bias; |
563 | TF_LITE_ENSURE_OK(context, |
564 | GetInputSafe(context, node, lstm::full::kCellGateBiasTensor, |
565 | &cell_gate_bias)); |
566 | TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1); |
567 | TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell); |
568 | if (is_integer) { |
569 | TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32); |
570 | } else { |
571 | TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32); |
572 | } |
573 | |
574 | const TfLiteTensor* output_gate_bias; |
575 | TF_LITE_ENSURE_OK( |
576 | context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor, |
577 | &output_gate_bias)); |
578 | TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); |
579 | TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); |
580 | if (is_integer) { |
581 | TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32); |
582 | } else { |
583 | TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32); |
584 | } |
585 | |
586 | const TfLiteTensor* projection_weights = GetOptionalInputTensor( |
587 | context, node, lstm::full::kProjectionWeightsTensor); |
588 | if (projection_weights != nullptr) { |
589 | TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); |
590 | TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); |
591 | TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); |
592 | } |
593 | |
594 | const TfLiteTensor* projection_bias = |
595 | GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor); |
596 | if (projection_bias != nullptr) { |
597 | TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); |
598 | TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); |
599 | if (is_integer) { |
600 | TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32); |
601 | } else { |
602 | TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32); |
603 | } |
604 | } |
605 | |
606 | // Making sure the projection tensors are consistent: |
607 | // 1) If projection weight is not present, then projection bias should not be |
608 | // present. |
609 | // 2) If projection weight is present, then projection bias is optional. |
610 | // TODO(ghodrat): make sure this is correct. |
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);
614 | |
615 | if (use_layer_norm) { |
616 | const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( |
617 | context, node, lstm::full::kInputLayerNormCoefficientsTensor); |
618 | if (use_cifg) { |
619 | TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr); |
620 | } else { |
621 | TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr); |
622 | TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); |
623 | TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0], |
624 | n_cell); |
625 | if (is_integer) { |
626 | TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type, |
627 | kTfLiteInt16); |
628 | } else { |
629 | TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type, |
630 | kTfLiteFloat32); |
631 | } |
632 | } |
633 | |
634 | const TfLiteTensor* forget_layer_norm_coefficients; |
635 | TF_LITE_ENSURE_OK( |
636 | context, GetInputSafe(context, node, |
637 | lstm::full::kForgetLayerNormCoefficientsTensor, |
638 | &forget_layer_norm_coefficients)); |
639 | TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1); |
640 | TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0], |
641 | n_cell); |
642 | if (is_integer) { |
643 | TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type, |
644 | kTfLiteInt16); |
645 | } else { |
646 | TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type, |
647 | kTfLiteFloat32); |
648 | } |
649 | |
650 | const TfLiteTensor* cell_layer_norm_coefficients; |
651 | TF_LITE_ENSURE_OK(context, |
652 | GetInputSafe(context, node, |
653 | lstm::full::kCellLayerNormCoefficientsTensor, |
654 | &cell_layer_norm_coefficients)); |
655 | TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1); |
656 | TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0], |
657 | n_cell); |
658 | if (is_integer) { |
659 | TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type, |
660 | kTfLiteInt16); |
661 | } else { |
662 | TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type, |
663 | kTfLiteFloat32); |
664 | } |
665 | |
666 | const TfLiteTensor* output_layer_norm_coefficients; |
667 | TF_LITE_ENSURE_OK( |
668 | context, GetInputSafe(context, node, |
669 | lstm::full::kOutputLayerNormCoefficientsTensor, |
670 | &output_layer_norm_coefficients)); |
671 | TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1); |
672 | TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0], |
673 | n_cell); |
674 | if (is_integer) { |
675 | TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type, |
676 | kTfLiteInt16); |
677 | } else { |
678 | TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type, |
679 | kTfLiteFloat32); |
680 | } |
681 | } |
682 | |
683 | return kTfLiteOk; |
684 | } |
685 | |
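// For asymmetrically quantized values, (x_q - x_zp) . W = x_q . W - x_zp * W.
// The second term does not depend on the input: it is one value per weight
// row, so it can be precomputed once and folded into the bias. Callers pass
// the negated zero point, so the result below is bias - zp * row_sum(W).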
686 | TfLiteStatus PrecomputeZeroPointTimesWeightWithBias( |
687 | TfLiteContext* context, int32_t zero_point, |
688 | const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor, |
689 | std::unique_ptr<int32_t[]>* output) { |
690 | if (weight_tensor == nullptr) { |
691 | return kTfLiteOk; |
692 | } |
693 | |
694 | const RuntimeShape& weight_shape = GetTensorShape(weight_tensor); |
695 | TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2); |
696 | const int row = weight_shape.Dims(0); |
697 | const int col = weight_shape.Dims(1); |
698 | output->reset(new int32_t[row]); |
699 | if (bias_tensor == nullptr) { |
700 | memset(output->get(), 0, row * sizeof(int32_t)); |
701 | } else { |
702 | const int32_t* bias = GetTensorData<int32_t>(bias_tensor); |
703 | memcpy(output->get(), bias, row * sizeof(int32_t)); |
704 | } |
705 | if (zero_point != 0) { |
706 | const int8_t* weight = GetTensorData<int8_t>(weight_tensor); |
707 | tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col, |
708 | output->get()); |
709 | } |
710 | return kTfLiteOk; |
711 | } |
712 | |
713 | TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context, |
714 | OpData* op_data, |
715 | TfLiteNode* node) { |
716 | const TfLiteTensor* input; |
717 | TF_LITE_ENSURE_OK( |
718 | context, GetInputSafe(context, node, lstm::full::kInputTensor, &input)); |
719 | const TfLiteTensor* output_state = |
720 | GetVariableInput(context, node, lstm::full::kOutputStateTensor); |
721 | TF_LITE_ENSURE(context, output_state != nullptr); |
722 | |
723 | const int32_t input_zero_point = -input->params.zero_point; |
724 | const int32_t output_state_zero_point = -output_state->params.zero_point; |
725 | |
726 | const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor( |
727 | context, node, lstm::full::kInputToInputWeightsTensor); |
728 | const TfLiteTensor* input_to_forget_weights; |
729 | TF_LITE_ENSURE_OK( |
730 | context, |
731 | GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor, |
732 | &input_to_forget_weights)); |
733 | const TfLiteTensor* input_to_cell_weights; |
734 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, |
735 | lstm::full::kInputToCellWeightsTensor, |
736 | &input_to_cell_weights)); |
737 | const TfLiteTensor* input_to_output_weights; |
738 | TF_LITE_ENSURE_OK( |
739 | context, |
740 | GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor, |
741 | &input_to_output_weights)); |
742 | |
743 | const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor( |
744 | context, node, lstm::full::kRecurrentToInputWeightsTensor); |
745 | const TfLiteTensor* recurrent_to_forget_weights; |
746 | TF_LITE_ENSURE_OK( |
747 | context, |
748 | GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor, |
749 | &recurrent_to_forget_weights)); |
750 | const TfLiteTensor* recurrent_to_cell_weights; |
751 | TF_LITE_ENSURE_OK( |
752 | context, |
753 | GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor, |
754 | &recurrent_to_cell_weights)); |
755 | const TfLiteTensor* recurrent_to_output_weights; |
756 | TF_LITE_ENSURE_OK( |
757 | context, |
758 | GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor, |
759 | &recurrent_to_output_weights)); |
760 | |
761 | const TfLiteTensor* projection_weights = GetOptionalInputTensor( |
762 | context, node, lstm::full::kProjectionWeightsTensor); |
763 | const TfLiteTensor* projection_bias = |
764 | GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor); |
765 | |
766 | lstm_eval::IntegerLstmParameter* integer_lstm_params = |
767 | &op_data->integer_lstm_param; |
768 | |
769 | const TfLiteTensor* intermediate = |
770 | &context->tensors[node->intermediates->data[4]]; |
771 | TF_LITE_ENSURE(context, |
772 | intermediate->quantization.type != kTfLiteNoQuantization); |
773 | const auto* params = |
774 | static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params); |
775 | const int32_t hidden_zp = params->zero_point->data[0]; |
776 | |
777 | // Get bias and perform zero point calculation. |
778 | // When there is layer normalization, the gate bias does not apply to matmul |
779 | // directly: |
780 | // y = ln(w * x + w * r + w * c) + b. |
781 | const bool is_layer_norm = op_data->use_layer_norm; |
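  // Consequently, when layer norm is present the gate biases are applied
  // after the normalization and cannot be folded into the precomputed
  // zero-point terms; nullptr is passed for the bias instead.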
782 | |
783 | // Forget gate. |
784 | const TfLiteTensor* forget_gate_bias = |
785 | is_layer_norm |
786 | ? nullptr |
787 | : GetInput(context, node, lstm::full::kForgetGateBiasTensor); |
788 | TF_LITE_ENSURE_OK( |
789 | context, |
790 | PrecomputeZeroPointTimesWeightWithBias( |
791 | context, input_zero_point, input_to_forget_weights, forget_gate_bias, |
792 | &(integer_lstm_params->input_to_forget_effective_bias))); |
793 | |
794 | TF_LITE_ENSURE_OK( |
795 | context, |
796 | PrecomputeZeroPointTimesWeightWithBias( |
797 | context, output_state_zero_point, recurrent_to_forget_weights, |
798 | nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias))); |
799 | |
800 | // Modulation gate. |
801 | const TfLiteTensor* cell_gate_bias = |
802 | is_layer_norm ? nullptr |
803 | : GetInput(context, node, lstm::full::kCellGateBiasTensor); |
804 | TF_LITE_ENSURE_OK( |
805 | context, |
806 | PrecomputeZeroPointTimesWeightWithBias( |
807 | context, input_zero_point, input_to_cell_weights, cell_gate_bias, |
808 | &(integer_lstm_params->input_to_cell_effective_bias))); |
809 | TF_LITE_ENSURE_OK( |
810 | context, |
811 | PrecomputeZeroPointTimesWeightWithBias( |
812 | context, output_state_zero_point, recurrent_to_cell_weights, nullptr, |
813 | &(integer_lstm_params->recurrent_to_cell_effective_bias))); |
814 | |
815 | // Output gate. |
816 | const TfLiteTensor* output_gate_bias = |
817 | is_layer_norm |
818 | ? nullptr |
819 | : GetInput(context, node, lstm::full::kOutputGateBiasTensor); |
820 | TF_LITE_ENSURE_OK( |
821 | context, |
822 | PrecomputeZeroPointTimesWeightWithBias( |
823 | context, input_zero_point, input_to_output_weights, output_gate_bias, |
824 | &(integer_lstm_params->input_to_output_effective_bias))); |
825 | |
826 | TF_LITE_ENSURE_OK( |
827 | context, |
828 | PrecomputeZeroPointTimesWeightWithBias( |
829 | context, output_state_zero_point, recurrent_to_output_weights, |
830 | nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias))); |
831 | |
  // Input gate. The calculation is only meaningful for the non-CIFG case.
833 | const TfLiteTensor* input_gate_bias = |
834 | is_layer_norm ? nullptr |
835 | : GetInput(context, node, lstm::full::kInputGateBiasTensor); |
836 | TF_LITE_ENSURE_OK( |
837 | context, |
838 | PrecomputeZeroPointTimesWeightWithBias( |
839 | context, input_zero_point, input_to_input_weights, input_gate_bias, |
840 | &(integer_lstm_params->input_to_input_effective_bias))); |
841 | TF_LITE_ENSURE_OK( |
842 | context, |
843 | PrecomputeZeroPointTimesWeightWithBias( |
844 | context, output_state_zero_point, recurrent_to_input_weights, nullptr, |
845 | &(integer_lstm_params->recurrent_to_input_effective_bias))); |
846 | |
  // Projection bias. The calculation is only meaningful when projection is
  // used.
848 | TF_LITE_ENSURE_OK(context, |
849 | PrecomputeZeroPointTimesWeightWithBias( |
850 | context, hidden_zp, projection_weights, projection_bias, |
851 | &(integer_lstm_params->projection_effective_bias))); |
852 | return kTfLiteOk; |
853 | } |
854 | |
855 | // Resize the output and state tensors based on the sizes of the input tensors. |
856 | // Allocate a temporary scratch tensor. Also check that the sizes of the input |
857 | // tensors match each other. |
858 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
859 | OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
860 | const int scratch_tensor_index = op_data->scratch_tensor_index; |
861 | |
862 | // Check we have all the inputs and outputs we need. |
863 | bool use_layer_norm = false; |
864 | if (node->inputs->size == 24) { |
865 | const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor( |
866 | context, node, lstm::full::kForgetLayerNormCoefficientsTensor); |
867 | if (forget_layer_norm_coefficients == nullptr) { |
868 | use_layer_norm = false; |
869 | } else { |
870 | use_layer_norm = true; |
871 | } |
872 | } else if (node->inputs->size == 20) { |
873 | // This is deprecated and is only kept here for backward compatibility. |
874 | use_layer_norm = false; |
875 | } else { |
876 | TF_LITE_KERNEL_LOG( |
877 | context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs" , |
878 | node->inputs->size); |
879 | return kTfLiteError; |
880 | } |
881 | TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); |
882 | op_data->use_layer_norm = use_layer_norm; |
883 | |
  // Infer the batch size, sequence length, number of inputs, number of
  // outputs, and number of cells from the input tensors.
886 | const TfLiteTensor* input; |
887 | TF_LITE_ENSURE_OK( |
888 | context, GetInputSafe(context, node, lstm::full::kInputTensor, &input)); |
889 | const bool is_integer = input->type == kTfLiteInt8; |
890 | TF_LITE_ENSURE(context, input->dims->size > 1); |
891 | const auto* params = |
892 | reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>( |
893 | node->builtin_data); |
894 | const bool time_major = params->time_major; |
895 | const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0]; |
896 | const int n_input = input->dims->data[2]; |
897 | |
898 | const TfLiteTensor* input_to_output_weights; |
899 | TF_LITE_ENSURE_OK( |
900 | context, |
901 | GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor, |
902 | &input_to_output_weights)); |
903 | const int n_cell = input_to_output_weights->dims->data[0]; |
904 | TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); |
905 | TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); |
906 | |
907 | const TfLiteTensor* recurrent_to_output_weights; |
908 | TF_LITE_ENSURE_OK( |
909 | context, |
910 | GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor, |
911 | &recurrent_to_output_weights)); |
912 | TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); |
913 | TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], |
914 | n_cell); |
915 | const int n_output = recurrent_to_output_weights->dims->data[1]; |
916 | |
917 | // Check that input tensor dimensions matches with each other. |
918 | TF_LITE_ENSURE_OK( |
919 | context, CheckInputTensorDimensions(context, node, n_input, n_output, |
920 | n_cell, use_layer_norm, is_integer)); |
921 | |
922 | // Get the pointer to output, output_state and cell_state buffer tensors. |
923 | TfLiteTensor* output; |
924 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, |
925 | lstm::full::kOutputTensor, &output)); |
926 | |
927 | TfLiteTensor* output_state = |
928 | GetVariableInput(context, node, lstm::full::kOutputStateTensor); |
929 | TF_LITE_ENSURE(context, output_state != nullptr); |
930 | TfLiteTensor* cell_state = |
931 | GetVariableInput(context, node, lstm::full::kCellStateTensor); |
932 | TF_LITE_ENSURE(context, cell_state != nullptr); |
933 | |
  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D. That is fine as long as the total size is
  // correct.
937 | TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output); |
938 | TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); |
939 | |
940 | // Resize the output tensors. |
941 | TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims); |
942 | output_size->data[input->dims->size - 1] = n_output; |
943 | TF_LITE_ENSURE_OK(context, |
944 | context->ResizeTensor(context, output, output_size)); |
945 | |
946 | if (is_integer) { |
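    // The 8x8->16 kernel carries five intermediate tensors: one per gate
    // pre-activation (input, forget, cell, output) plus the hidden output.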
947 | const int num_intermediate_tensors = node->intermediates->size; |
948 | TF_LITE_ENSURE(context, num_intermediate_tensors == 5); |
949 | } |
950 | |
951 | TfLiteIntArrayFree(node->temporaries); |
952 | if (IsHybridOp(input, input_to_output_weights)) { |
953 | node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); |
954 | } else if (is_integer) { |
955 | node->temporaries = TfLiteIntArrayCreate(6); |
956 | } else { |
957 | node->temporaries = TfLiteIntArrayCreate(1); |
958 | } |
959 | node->temporaries->data[kScratchBuffer] = |
960 | scratch_tensor_index + kScratchBuffer; |
961 | |
962 | // Create a scratch buffer tensor. |
963 | TfLiteTensor* scratch_buffer; |
964 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer, |
965 | &scratch_buffer)); |
966 | scratch_buffer->type = input->type; |
967 | scratch_buffer->allocation_type = kTfLiteArenaRw; |
968 | |
969 | const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor( |
970 | context, node, lstm::full::kInputToInputWeightsTensor); |
971 | const bool use_cifg = (input_to_input_weights == nullptr); |
972 | TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); |
973 | scratch_buffer_size->data[0] = n_batch; |
974 | if (use_cifg) { |
975 | // Reserving space for Cell, Forget, Output gates and scratch accumulation |
976 | // buffer and an extra 16 bytes to avoid internal ruy copies. |
977 | scratch_buffer_size->data[1] = n_cell * 4 + 16; |
978 | } else { |
979 | // Reserving space for Input, Cell, Forget, Output gates and scratch |
980 | // accumulation buffer and an extra 16 bytes to avoid internal ruy copies. |
981 | scratch_buffer_size->data[1] = n_cell * 5 + 16; |
982 | } |
983 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, |
984 | scratch_buffer_size)); |
985 | |
986 | if (IsHybridOp(input, input_to_output_weights)) { |
987 | op_data->compute_row_sums = true; |
988 | // Allocate temporary tensors to store quantized values of input, |
989 | // output_state and cell_state tensors. |
990 | node->temporaries->data[kInputQuantized] = |
991 | scratch_tensor_index + kInputQuantized; |
992 | TfLiteTensor* input_quantized; |
993 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized, |
994 | &input_quantized)); |
995 | input_quantized->type = input_to_output_weights->type; |
996 | input_quantized->allocation_type = kTfLiteArenaRw; |
997 | if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { |
998 | TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); |
999 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, |
1000 | input_quantized_size)); |
1001 | } |
1002 | node->temporaries->data[kOutputStateQuantized] = |
1003 | scratch_tensor_index + kOutputStateQuantized; |
1004 | TfLiteTensor* output_state_quantized; |
1005 | TF_LITE_ENSURE_OK(context, |
1006 | GetTemporarySafe(context, node, kOutputStateQuantized, |
1007 | &output_state_quantized)); |
1008 | output_state_quantized->type = input_to_output_weights->type; |
1009 | output_state_quantized->allocation_type = kTfLiteArenaRw; |
1010 | if (!TfLiteIntArrayEqual(output_state_quantized->dims, |
1011 | output_state->dims)) { |
1012 | TfLiteIntArray* output_state_quantized_size = |
1013 | TfLiteIntArrayCopy(output_state->dims); |
1014 | TF_LITE_ENSURE_OK(context, |
1015 | context->ResizeTensor(context, output_state_quantized, |
1016 | output_state_quantized_size)); |
1017 | } |
1018 | node->temporaries->data[kCellStateQuantized] = |
1019 | scratch_tensor_index + kCellStateQuantized; |
1020 | TfLiteTensor* cell_state_quantized; |
1021 | TF_LITE_ENSURE_OK(context, |
1022 | GetTemporarySafe(context, node, kCellStateQuantized, |
1023 | &cell_state_quantized)); |
1024 | cell_state_quantized->type = input_to_output_weights->type; |
1025 | cell_state_quantized->allocation_type = kTfLiteArenaRw; |
1026 | if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { |
1027 | TfLiteIntArray* cell_state_quantized_size = |
1028 | TfLiteIntArrayCopy(cell_state->dims); |
1029 | TF_LITE_ENSURE_OK(context, |
1030 | context->ResizeTensor(context, cell_state_quantized, |
1031 | cell_state_quantized_size)); |
1032 | } |
1033 | |
    // Allocate temporary tensors to store the scaling factors and product
    // scaling factors. The latter is convenience storage that lets a vector
    // be quantized once (producing its scaling factors) and then multiplied
    // with different matrices (which requires multiplying those scaling
    // factors by the scaling factor of each matrix).
1039 | node->temporaries->data[kInputScalingFactors] = |
1040 | op_data->scratch_tensor_index + kInputScalingFactors; |
1041 | TfLiteTensor* input_sf; |
1042 | TF_LITE_ENSURE_OK( |
1043 | context, |
1044 | GetTemporarySafe(context, node, kInputScalingFactors, &input_sf)); |
1045 | input_sf->type = kTfLiteFloat32; |
1046 | input_sf->allocation_type = kTfLiteArenaRw; |
1047 | int scaling_dims[1] = {n_batch}; |
1048 | if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) { |
1049 | TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1); |
1050 | input_sf_size->data[0] = n_batch; |
1051 | TF_LITE_ENSURE_OK( |
1052 | context, context->ResizeTensor(context, input_sf, input_sf_size)); |
1053 | } |
1054 | node->temporaries->data[kOutputStateScalingFactors] = |
1055 | op_data->scratch_tensor_index + kOutputStateScalingFactors; |
1056 | TfLiteTensor* output_state_sf; |
1057 | TF_LITE_ENSURE_OK( |
1058 | context, GetTemporarySafe(context, node, kOutputStateScalingFactors, |
1059 | &output_state_sf)); |
1060 | output_state_sf->type = kTfLiteFloat32; |
1061 | output_state_sf->allocation_type = kTfLiteArenaRw; |
1062 | if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { |
1063 | TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1); |
1064 | output_state_sf_size->data[0] = n_batch; |
1065 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf, |
1066 | output_state_sf_size)); |
1067 | } |
1068 | node->temporaries->data[kProductScalingFactors] = |
1069 | scratch_tensor_index + kProductScalingFactors; |
1070 | TfLiteTensor* prod_scaling_factors; |
1071 | TF_LITE_ENSURE_OK(context, |
1072 | GetTemporarySafe(context, node, kProductScalingFactors, |
1073 | &prod_scaling_factors)); |
1074 | prod_scaling_factors->type = kTfLiteFloat32; |
1075 | prod_scaling_factors->allocation_type = kTfLiteArenaRw; |
1076 | if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1, |
1077 | scaling_dims)) { |
1078 | TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); |
1079 | prod_scaling_factors_size->data[0] = n_batch; |
1080 | TF_LITE_ENSURE_OK(context, |
1081 | context->ResizeTensor(context, prod_scaling_factors, |
1082 | prod_scaling_factors_size)); |
1083 | } |
1084 | |
    // Allocate a temporary tensor to store the recovered cell weights. Since
    // this is used for diagonal matrices, we only need to store n_cell
    // values.
1087 | node->temporaries->data[kRecoveredCellWeights] = |
1088 | scratch_tensor_index + kRecoveredCellWeights; |
1089 | TfLiteTensor* recovered_cell_weights; |
1090 | TF_LITE_ENSURE_OK(context, |
1091 | GetTemporarySafe(context, node, kRecoveredCellWeights, |
1092 | &recovered_cell_weights)); |
1093 | recovered_cell_weights->type = kTfLiteFloat32; |
1094 | recovered_cell_weights->allocation_type = kTfLiteArenaRw; |
1095 | int recovered_cell_dims[1] = {n_cell}; |
1096 | if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1, |
1097 | recovered_cell_dims)) { |
1098 | TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); |
1099 | recovered_cell_weights_size->data[0] = n_cell; |
1100 | TF_LITE_ENSURE_OK(context, |
1101 | context->ResizeTensor(context, recovered_cell_weights, |
1102 | recovered_cell_weights_size)); |
1103 | } |
1104 | |
1105 | // Allocate a temporary tensor to store the accumulated int32 values. |
1106 | node->temporaries->data[kAccumScratch] = |
1107 | scratch_tensor_index + kAccumScratch; |
1108 | TfLiteTensor* accum_scratch; |
1109 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch, |
1110 | &accum_scratch)); |
1111 | accum_scratch->type = kTfLiteInt32; |
1112 | accum_scratch->allocation_type = kTfLiteArenaRw; |
1113 | int accum_scratch_dims[2] = {n_cell, n_batch}; |
1114 | if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, |
1115 | accum_scratch_dims)) { |
1116 | TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2); |
1117 | accum_size->data[0] = n_cell; |
1118 | accum_size->data[1] = n_batch; |
1119 | TF_LITE_ENSURE_OK( |
1120 | context, context->ResizeTensor(context, accum_scratch, accum_size)); |
1121 | } |
1122 | node->temporaries->data[kInputZeroPoints] = |
1123 | op_data->scratch_tensor_index + kInputZeroPoints; |
1124 | TfLiteTensor* input_zp; |
1125 | TF_LITE_ENSURE_OK( |
1126 | context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp)); |
1127 | input_zp->type = kTfLiteFloat32; |
1128 | input_zp->allocation_type = kTfLiteArenaRw; |
1129 | if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { |
1130 | TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1); |
1131 | input_zp_size->data[0] = n_batch; |
1132 | TF_LITE_ENSURE_OK( |
1133 | context, context->ResizeTensor(context, input_zp, input_zp_size)); |
1134 | } |
1135 | node->temporaries->data[kOutputStateZeroPoints] = |
1136 | op_data->scratch_tensor_index + kOutputStateZeroPoints; |
1137 | TfLiteTensor* output_state_zp; |
1138 | TF_LITE_ENSURE_OK(context, |
1139 | GetTemporarySafe(context, node, kOutputStateZeroPoints, |
1140 | &output_state_zp)); |
1141 | output_state_zp->type = kTfLiteFloat32; |
1142 | output_state_zp->allocation_type = kTfLiteArenaRw; |
1143 | if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { |
1144 | TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1); |
1145 | output_state_zp_size->data[0] = n_batch; |
1146 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp, |
1147 | output_state_zp_size)); |
1148 | } |
1149 | node->temporaries->data[kRowSums] = scratch_tensor_index + kRowSums; |
1150 | TfLiteTensor* row_sums; |
1151 | TF_LITE_ENSURE_OK(context, |
1152 | GetTemporarySafe(context, node, kRowSums, &row_sums)); |
1153 | row_sums->type = kTfLiteInt32; |
    row_sums->name = "Lstm_row_sums";
1155 | row_sums->allocation_type = kTfLiteArenaRwPersistent; |
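    // One row of sums per weight matrix: four input-to-gate and four
    // recurrent-to-gate matrices (two fewer when CIFG drops the input gate),
    // plus enough n_cell-wide rows to cover the n_output rows of the
    // projection matrix.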
1156 | int row_sums_rows = use_cifg ? 6 : 8; |
1157 | const TfLiteTensor* projection_weights = GetOptionalInputTensor( |
1158 | context, node, lstm::full::kProjectionWeightsTensor); |
1159 | if (projection_weights != nullptr) { |
1160 | row_sums_rows += ceil(static_cast<float>(n_output) / n_cell); |
1161 | } |
1162 | int row_sums_dims[2] = {row_sums_rows, n_cell}; |
1163 | if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { |
1164 | TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); |
1165 | row_sums_size->data[0] = row_sums_dims[0]; |
1166 | row_sums_size->data[1] = row_sums_dims[1]; |
1167 | TF_LITE_ENSURE_OK( |
1168 | context, context->ResizeTensor(context, row_sums, row_sums_size)); |
1169 | } |
1170 | } |
1171 | |
1172 | if (is_integer) { |
1173 | // Integer UnidirectionalSequenceLSTM prepare function for 8x8->16. |
1174 | // This code path needs 5 intermediate tensors per Op. |
1175 | // Populate quantization parameters. |
    TF_LITE_ENSURE_OK(context, PopulateQuantizedLstmParams8x8_16(
                                   context, node,
                                   &op_data->integer_lstm_param));
    // Allocate the scratch buffers: six 16-bit buffers of size
    // n_batch * n_cell, one 8-bit buffer of the same size, and one 32-bit
    // buffer of the same size.
1181 | // |
1182 | // Handle cifg case as well, which might save one buffer. |
1183 | for (int scratch_index = 0; scratch_index < 6; ++scratch_index) { |
1184 | node->temporaries->data[scratch_index] = |
1185 | op_data->scratch_tensor_index + scratch_index; |
1186 | TfLiteTensor* scratch_tensor; |
1187 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, scratch_index, |
1188 | &scratch_tensor)); |
1189 | |
1190 | scratch_tensor->type = kTfLiteInt16; |
1191 | if (scratch_index == 4) { |
1192 | scratch_tensor->type = kTfLiteInt8; |
1193 | } else if (scratch_index == 5) { |
1194 | scratch_tensor->type = kTfLiteInt32; |
1195 | } |
1196 | |
1197 | scratch_tensor->allocation_type = kTfLiteArenaRw; |
1198 | const int scratch_dimension[2] = {n_batch, n_cell}; |
1199 | if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2, |
1200 | scratch_dimension)) { |
1201 | TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); |
1202 | scratch_buffer_size->data[0] = n_batch; |
1203 | scratch_buffer_size->data[1] = n_cell; |
1204 | TF_LITE_ENSURE_OK(context, |
1205 | context->ResizeTensor(context, scratch_tensor, |
1206 | scratch_buffer_size)); |
1207 | } |
1208 | } |
1209 | |
1210 | // Populate precomputed zp * weight. |
1211 | TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias( |
1212 | context, op_data, node)); |
1213 | } |
1214 | |
1215 | return kTfLiteOk; |
1216 | } |
1217 | |
1218 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
1219 | const auto* params = |
1220 | reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>( |
1221 | node->builtin_data); |
1222 | const OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
1223 | const bool use_layer_norm = op_data->use_layer_norm; |
1224 | const bool time_major = params->time_major; |
1225 | const TfLiteTensor* input; |
1226 | TF_LITE_ENSURE_OK( |
1227 | context, GetInputSafe(context, node, lstm::full::kInputTensor, &input)); |
1228 | |
1229 | const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor( |
1230 | context, node, lstm::full::kInputToInputWeightsTensor); |
1231 | const TfLiteTensor* input_to_forget_weights; |
1232 | TF_LITE_ENSURE_OK( |
1233 | context, |
1234 | GetInputSafe(context, node, lstm::full::kInputToForgetWeightsTensor, |
1235 | &input_to_forget_weights)); |
1236 | const TfLiteTensor* input_to_cell_weights; |
1237 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, |
1238 | lstm::full::kInputToCellWeightsTensor, |
1239 | &input_to_cell_weights)); |
1240 | const TfLiteTensor* input_to_output_weights; |
1241 | TF_LITE_ENSURE_OK( |
1242 | context, |
1243 | GetInputSafe(context, node, lstm::full::kInputToOutputWeightsTensor, |
1244 | &input_to_output_weights)); |
1245 | |
1246 | const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor( |
1247 | context, node, lstm::full::kRecurrentToInputWeightsTensor); |
1248 | const TfLiteTensor* recurrent_to_forget_weights; |
1249 | TF_LITE_ENSURE_OK( |
1250 | context, |
1251 | GetInputSafe(context, node, lstm::full::kRecurrentToForgetWeightsTensor, |
1252 | &recurrent_to_forget_weights)); |
1253 | const TfLiteTensor* recurrent_to_cell_weights; |
1254 | TF_LITE_ENSURE_OK( |
1255 | context, |
1256 | GetInputSafe(context, node, lstm::full::kRecurrentToCellWeightsTensor, |
1257 | &recurrent_to_cell_weights)); |
1258 | const TfLiteTensor* recurrent_to_output_weights; |
1259 | TF_LITE_ENSURE_OK( |
1260 | context, |
1261 | GetInputSafe(context, node, lstm::full::kRecurrentToOutputWeightsTensor, |
1262 | &recurrent_to_output_weights)); |
1263 | |
1264 | const TfLiteTensor* cell_to_input_weights = GetOptionalInputTensor( |
1265 | context, node, lstm::full::kCellToInputWeightsTensor); |
1266 | const TfLiteTensor* cell_to_forget_weights = GetOptionalInputTensor( |
1267 | context, node, lstm::full::kCellToForgetWeightsTensor); |
1268 | const TfLiteTensor* cell_to_output_weights = GetOptionalInputTensor( |
1269 | context, node, lstm::full::kCellToOutputWeightsTensor); |
1270 | |
1271 | const TfLiteTensor* input_gate_bias = |
1272 | GetOptionalInputTensor(context, node, lstm::full::kInputGateBiasTensor); |
1273 | const TfLiteTensor* forget_gate_bias; |
1274 | TF_LITE_ENSURE_OK( |
1275 | context, GetInputSafe(context, node, lstm::full::kForgetGateBiasTensor, |
1276 | &forget_gate_bias)); |
1277 | const TfLiteTensor* cell_gate_bias; |
1278 | TF_LITE_ENSURE_OK(context, |
1279 | GetInputSafe(context, node, lstm::full::kCellGateBiasTensor, |
1280 | &cell_gate_bias)); |
1281 | const TfLiteTensor* output_gate_bias; |
1282 | TF_LITE_ENSURE_OK( |
1283 | context, GetInputSafe(context, node, lstm::full::kOutputGateBiasTensor, |
1284 | &output_gate_bias)); |
1285 | |
1286 | const TfLiteTensor* projection_weights = GetOptionalInputTensor( |
1287 | context, node, lstm::full::kProjectionWeightsTensor); |
1288 | const TfLiteTensor* projection_bias = |
1289 | GetOptionalInputTensor(context, node, lstm::full::kProjectionBiasTensor); |
1290 | |
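  // Output and cell states are variable tensors: they persist across
  // invocations and carry the recurrent state from one call to the next.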
1291 | TfLiteTensor* output_state = |
1292 | GetVariableInput(context, node, lstm::full::kOutputStateTensor); |
1293 | TFLITE_DCHECK(output_state != nullptr); |
1294 | TfLiteTensor* cell_state = |
1295 | GetVariableInput(context, node, lstm::full::kCellStateTensor); |
1296 | TFLITE_DCHECK(cell_state != nullptr); |
1297 | |
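  // Layer-norm coefficients are fetched only for layer-norm LSTMs. The
  // input coefficients stay optional because a CIFG model has no input
  // gate.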
1298 | const TfLiteTensor* input_layer_norm_coefficients = |
1299 | use_layer_norm |
1300 | ? GetOptionalInputTensor( |
1301 | context, node, lstm::full::kInputLayerNormCoefficientsTensor) |
1302 | : nullptr; |
1303 | const TfLiteTensor* forget_layer_norm_coefficients = |
1304 | use_layer_norm ? GetInput(context, node, |
1305 | lstm::full::kForgetLayerNormCoefficientsTensor) |
1306 | : nullptr; |
1307 | const TfLiteTensor* cell_layer_norm_coefficients = |
1308 | use_layer_norm ? GetInput(context, node, |
1309 | lstm::full::kCellLayerNormCoefficientsTensor) |
1310 | : nullptr; |
1311 | const TfLiteTensor* output_layer_norm_coefficients = |
1312 | use_layer_norm ? GetInput(context, node, |
1313 | lstm::full::kOutputLayerNormCoefficientsTensor) |
1314 | : nullptr; |
1315 | |
1316 | TfLiteTensor* output; |
1317 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, |
1318 | lstm::full::kOutputTensor, &output)); |
1319 | |
  // Copy out the LSTM-specific params so they can be passed to the eval
  // functions.
1321 | TfLiteLSTMParams lstm_params; |
1322 | lstm_params.activation = params->activation; |
1323 | lstm_params.cell_clip = params->cell_clip; |
1324 | lstm_params.proj_clip = params->proj_clip; |
1325 | lstm_params.asymmetric_quantize_inputs = params->asymmetric_quantize_inputs; |
1326 | |
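  // Dispatch on the weight type: float32 weights select the float kernel;
  // 8-bit weights select either the hybrid kernel or the fully integer
  // 8x8->16 kernel, depending on the input type.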
1327 | switch (input_to_output_weights->type) { |
1328 | case kTfLiteFloat32: { |
      // Index the scratch buffer pointers into the global scratch buffer.
1330 | TfLiteTensor* scratch_buffer; |
1331 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer, |
1332 | &scratch_buffer)); |
1333 | return lstm_eval::EvalFloat( |
1334 | input, input_to_input_weights, input_to_forget_weights, |
1335 | input_to_cell_weights, input_to_output_weights, |
1336 | recurrent_to_input_weights, recurrent_to_forget_weights, |
1337 | recurrent_to_cell_weights, recurrent_to_output_weights, |
1338 | cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, |
1339 | input_layer_norm_coefficients, forget_layer_norm_coefficients, |
1340 | cell_layer_norm_coefficients, output_layer_norm_coefficients, |
1341 | /*aux_input=*/nullptr, |
1342 | /*aux_input_to_input_weights=*/nullptr, |
1343 | /*aux_input_to_forget_weights=*/nullptr, |
1344 | /*aux_input_to_cell_weights=*/nullptr, |
1345 | /*aux_input_to_output_weights=*/nullptr, input_gate_bias, |
1346 | forget_gate_bias, cell_gate_bias, output_gate_bias, |
1347 | projection_weights, projection_bias, &lstm_params, |
1348 | /*forward_sequence=*/true, time_major, |
1349 | /*output_offset=*/0, scratch_buffer, output_state, cell_state, output, |
1350 | CpuBackendContext::GetFromContext(context)); |
1351 | } |
1352 | case kTfLiteUInt8: |
1353 | case kTfLiteInt8: { |
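      // A float32 input with 8-bit weights means hybrid execution: inputs
      // are quantized on the fly and the weight matmuls run in the 8-bit
      // domain.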
1354 | const bool is_hybrid = input->type == kTfLiteFloat32; |
1355 | if (is_hybrid) { |
        // Index the scratch buffer pointers into the global scratch buffer.
1357 | TfLiteTensor* scratch_buffer; |
1358 | TF_LITE_ENSURE_OK( |
1359 | context, |
1360 | GetTemporarySafe(context, node, kScratchBuffer, &scratch_buffer)); |
1361 | |
1362 | OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
1363 | TfLiteTensor* row_sums; |
1364 | TF_LITE_ENSURE_OK(context, |
1365 | GetTemporarySafe(context, node, kRowSums, &row_sums)); |
1366 | const int row_sums_size = row_sums->dims->data[0]; |
1367 | return lstm_eval::EvalHybrid( |
1368 | input, input_to_input_weights, |
            /*input_to_input_weights_ledger=*/nullptr, input_to_forget_weights,
            /*input_to_forget_weights_ledger=*/nullptr, input_to_cell_weights,
            /*input_to_cell_weights_ledger=*/nullptr, input_to_output_weights,
            /*input_to_output_weights_ledger=*/nullptr,
            recurrent_to_input_weights,
            /*recurrent_to_input_weights_ledger=*/nullptr,
            recurrent_to_forget_weights,
            /*recurrent_to_forget_weights_ledger=*/nullptr,
            recurrent_to_cell_weights,
            /*recurrent_to_cell_weights_ledger=*/nullptr,
            recurrent_to_output_weights,
            /*recurrent_to_output_weights_ledger=*/nullptr,
1381 | cell_to_input_weights, cell_to_forget_weights, |
1382 | cell_to_output_weights, input_layer_norm_coefficients, |
1383 | forget_layer_norm_coefficients, cell_layer_norm_coefficients, |
1384 | output_layer_norm_coefficients, |
1385 | /*aux_input=*/nullptr, |
1386 | /*aux_input_to_input_weights=*/nullptr, |
1387 | /*aux_input_to_forget_weights=*/nullptr, |
1388 | /*aux_input_to_cell_weights=*/nullptr, |
1389 | /*aux_input_to_output_weights=*/nullptr, input_gate_bias, |
1390 | forget_gate_bias, cell_gate_bias, output_gate_bias, |
            projection_weights, /*projection_weights_ledger=*/nullptr,
1392 | projection_bias, &lstm_params, |
1393 | /*forward_sequence=*/true, time_major, |
1394 | /*output_offset=*/0, scratch_buffer, |
1395 | GetTemporary(context, node, kInputScalingFactors), |
1396 | /*aux_input_sf=*/nullptr, |
1397 | GetTemporary(context, node, kOutputStateScalingFactors), |
1398 | GetTemporary(context, node, kProductScalingFactors), |
1399 | GetTemporary(context, node, kRecoveredCellWeights), |
1400 | GetTemporary(context, node, kInputQuantized), |
1401 | /*aux_input_quantized=*/nullptr, |
1402 | GetTemporary(context, node, kOutputStateQuantized), |
1403 | GetTemporary(context, node, kCellStateQuantized), output_state, |
1404 | cell_state, GetTemporary(context, node, kAccumScratch), output, |
1405 | GetTemporary(context, node, kInputZeroPoints), |
1406 | /*aux_input_zp=*/nullptr, |
1407 | GetTemporary(context, node, kOutputStateZeroPoints), row_sums, |
1408 | row_sums_size, &op_data->compute_row_sums, |
1409 | CpuBackendContext::GetFromContext(context)); |
1410 | } else { |
1411 | TfLiteTensor* scratch0; |
1412 | TF_LITE_ENSURE_OK(context, |
1413 | GetTemporarySafe(context, node, 0, &scratch0)); |
1414 | TfLiteTensor* scratch1; |
1415 | TF_LITE_ENSURE_OK(context, |
1416 | GetTemporarySafe(context, node, 1, &scratch1)); |
1417 | TfLiteTensor* scratch2; |
1418 | TF_LITE_ENSURE_OK(context, |
1419 | GetTemporarySafe(context, node, 2, &scratch2)); |
1420 | TfLiteTensor* scratch3; |
1421 | TF_LITE_ENSURE_OK(context, |
1422 | GetTemporarySafe(context, node, 3, &scratch3)); |
1423 | TfLiteTensor* scratch4; |
1424 | TF_LITE_ENSURE_OK(context, |
1425 | GetTemporarySafe(context, node, 4, &scratch4)); |
1426 | TfLiteTensor* scratch5; |
1427 | TF_LITE_ENSURE_OK(context, |
1428 | GetTemporarySafe(context, node, 5, &scratch5)); |
1429 | return lstm_eval::EvalInteger8x8_16( |
1430 | input, input_to_input_weights, input_to_forget_weights, |
1431 | input_to_cell_weights, input_to_output_weights, |
1432 | recurrent_to_input_weights, recurrent_to_forget_weights, |
1433 | recurrent_to_cell_weights, recurrent_to_output_weights, |
1434 | cell_to_input_weights, cell_to_forget_weights, |
1435 | cell_to_output_weights, input_layer_norm_coefficients, |
1436 | forget_layer_norm_coefficients, cell_layer_norm_coefficients, |
1437 | output_layer_norm_coefficients, input_gate_bias, forget_gate_bias, |
1438 | cell_gate_bias, output_gate_bias, projection_weights, |
1439 | projection_bias, &lstm_params, /*forward_sequence=*/true, |
1440 | time_major, &op_data->integer_lstm_param, output_state, cell_state, |
1441 | output, scratch0, scratch1, scratch2, scratch3, scratch4, scratch5, |
1442 | CpuBackendContext::GetFromContext(context)); |
1443 | } |
1444 | } |
1445 | default: |
      TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
1447 | TfLiteTypeGetName(input_to_output_weights->type)); |
1448 | return kTfLiteError; |
1449 | } |
1450 | } |
1451 | } // namespace unidirectional_sequence_lstm |
1452 | |
1453 | TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() { |
1454 | static TfLiteRegistration r = {unidirectional_sequence_lstm::Init, |
1455 | unidirectional_sequence_lstm::Free, |
1456 | unidirectional_sequence_lstm::Prepare, |
1457 | unidirectional_sequence_lstm::Eval}; |
1458 | return &r; |
1459 | } |
1460 | |
1461 | } // namespace builtin |
1462 | } // namespace ops |
1463 | } // namespace tflite |
1464 | |