/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>

#include <algorithm>
#include <cstddef>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/lstm_eval.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_lstm {

// LINT.IfChange

// Input Tensors of size {max_time, n_batch, n_input}
constexpr int kInputTensor = 0;

// Forward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kFwInputToInputWeightsTensor = 1;  // Optional
constexpr int kFwInputToForgetWeightsTensor = 2;
constexpr int kFwInputToCellWeightsTensor = 3;
constexpr int kFwInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kFwRecurrentToInputWeightsTensor = 5;  // Optional
constexpr int kFwRecurrentToForgetWeightsTensor = 6;
constexpr int kFwRecurrentToCellWeightsTensor = 7;
constexpr int kFwRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kFwCellToInputWeightsTensor = 9;    // Optional
constexpr int kFwCellToForgetWeightsTensor = 10;  // Optional
constexpr int kFwCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kFwInputGateBiasTensor = 12;  // Optional
constexpr int kFwForgetGateBiasTensor = 13;
constexpr int kFwCellGateBiasTensor = 14;
constexpr int kFwOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kFwProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kFwProjectionBiasTensor = 17;  // Optional

// Backward LSTM cell tensors.
// Input weight tensors of size: {n_cell, n_input}
constexpr int kBwInputToInputWeightsTensor = 18;  // Optional
constexpr int kBwInputToForgetWeightsTensor = 19;
constexpr int kBwInputToCellWeightsTensor = 20;
constexpr int kBwInputToOutputWeightsTensor = 21;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr int kBwRecurrentToInputWeightsTensor = 22;  // Optional
constexpr int kBwRecurrentToForgetWeightsTensor = 23;
constexpr int kBwRecurrentToCellWeightsTensor = 24;
constexpr int kBwRecurrentToOutputWeightsTensor = 25;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr int kBwCellToInputWeightsTensor = 26;   // Optional
constexpr int kBwCellToForgetWeightsTensor = 27;  // Optional
constexpr int kBwCellToOutputWeightsTensor = 28;  // Optional

// Gates bias tensors of size {n_cell}
constexpr int kBwInputGateBiasTensor = 29;  // Optional
constexpr int kBwForgetGateBiasTensor = 30;
constexpr int kBwCellGateBiasTensor = 31;
constexpr int kBwOutputGateBiasTensor = 32;

// Projection weight tensor of size {n_output, n_cell}
constexpr int kBwProjectionWeightsTensor = 33;  // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34;  // Optional

// Stateful input tensors that are variables and will be modified by the Op.
// Activation state tensors of size {n_batch, n_output}
constexpr int kFwInputActivationStateTensor = 35;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kFwInputCellStateTensor = 36;
// Activation state tensors of size {n_batch, n_output}
constexpr int kBwInputActivationStateTensor = 37;
// Cell state tensors of size {n_batch, n_cell}
constexpr int kBwInputCellStateTensor = 38;

// Used as auxiliary input and weights when stacking for
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as
// input to the backward cell when stacking for tf.nn.static_bidirectional_rnn
// case (without cross links).
constexpr int kAuxInputTensor = 39;  // Optional
// Forward weights.
constexpr int kFwAuxInputToInputWeightsTensor = 40;   // Optional
constexpr int kFwAuxInputToForgetWeightsTensor = 41;  // Optional
constexpr int kFwAuxInputToCellWeightsTensor = 42;    // Optional
constexpr int kFwAuxInputToOutputWeightsTensor = 43;  // Optional
// Backward weights.
constexpr int kBwAuxInputToInputWeightsTensor = 44;   // Optional
constexpr int kBwAuxInputToForgetWeightsTensor = 45;  // Optional
constexpr int kBwAuxInputToCellWeightsTensor = 46;    // Optional
constexpr int kBwAuxInputToOutputWeightsTensor = 47;  // Optional

// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Ignored if merge_outputs is set.
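// Note: when merge_outputs is set, the forward output tensor holds both the
// forward and backward results, concatenated along the last (feature)
// dimension, and the backward output tensor is left untouched.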

// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc)

// Temporary tensors.
enum TemporaryTensor {
  // Scratch buffers for input, forget, etc. gates
  kFwScratchBuffer = 0,
  kBwScratchBuffer = 1,
  // Quantized tensors needed for the hybrid kernel.
  kInputQuantized = 2,
  kFwActivationStateQuantized = 3,
  kBwActivationStateQuantized = 4,
  kFwCellStateQuantized = 5,
  kBwCellStateQuantized = 6,
  kInputScalingFactors = 7,
  kAuxInputScalingFactors = 8,
  kOutputStateScalingFactors = 9,
  kProductScalingFactors = 10,
  kRecoveredCellWeights = 11,
  kAccumScratchBuffer = 12,
  kInputZeroPoints = 13,
  kAuxInputZeroPoints = 14,
  kOutputStateZeroPoints = 15,
  kFwRowSums = 16,
  kBwRowSums = 17,
  kAuxInputQuantized = 18,  // Optional, quantized tensor for auxiliary input.
  kNumTemporaryTensors = 19,
};
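// Only the two scratch buffers are allocated for the float kernel; the
// remaining temporaries are needed by the hybrid (quantized-weight) kernel
// only, and kAuxInputQuantized is allocated only when an auxiliary input is
// actually present.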

struct OpData {
  int scratch_tensor_index;
  bool compute_fw_row_sums = false;
  bool compute_bw_row_sums = false;
};
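// scratch_tensor_index is the index of the first tensor in the contiguous
// block of temporaries registered in Init(). The row-sum flags request a
// (re)computation of the cached weight row sums used by the hybrid kernel.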

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Check that the input tensor dimensions match each other.
TfLiteStatus CheckLstmTensorDimensionsAndTypes(
    TfLiteContext* context, TfLiteNode* node, int n_input, int n_output,
    int n_cell, int input_to_input_weights_tensor,
    int input_to_forget_weights_tensor, int input_to_cell_weights_tensor,
    int input_to_output_weights_tensor, int recurrent_to_input_weights_tensor,
    int recurrent_to_forget_weights_tensor,
    int recurrent_to_cell_weights_tensor,
    int recurrent_to_output_weights_tensor, int cell_to_input_weights_tensor,
    int cell_to_forget_weights_tensor, int cell_to_output_weights_tensor,
    int input_gate_bias_tensor, int forget_gate_bias_tensor,
    int cell_gate_bias_tensor, int output_gate_bias_tensor,
    int projection_weights_tensor, int projection_bias_tensor) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Making sure clipping parameters have valid values.
  // == 0 means no clipping
  //  > 0 means clipping
  TF_LITE_ENSURE(context, params->cell_clip >= 0);
  TF_LITE_ENSURE(context, params->proj_clip >= 0);

  const TfLiteTensor* input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_forget_weights_tensor,
                                 &input_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
  TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) ||
                              (input_to_forget_weights->type == kTfLiteInt8) ||
                              (input_to_forget_weights->type == kTfLiteUInt8));

  const TfLiteTensor* input_to_input_weights =
      GetOptionalInputTensor(context, node, input_to_input_weights_tensor);
  if (input_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
    TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_cell_weights_tensor,
                                 &input_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, input_to_output_weights_tensor,
                                 &input_to_output_weights));
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
  TF_LITE_ENSURE_TYPES_EQ(context, input_to_output_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, recurrent_to_input_weights_tensor);
  if (recurrent_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
                      n_cell);
    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
                      n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, recurrent_to_forget_weights_tensor,
                            &recurrent_to_forget_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
                    n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type,
                          input_to_forget_weights->type);

  const TfLiteTensor* recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, recurrent_to_cell_weights_tensor,
                            &recurrent_to_cell_weights));
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
                    n_output);
  TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type,
                          input_to_forget_weights->type);

  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
  const bool cifg_weights_all_or_none =
      ((input_to_input_weights != nullptr) &&
       (recurrent_to_input_weights != nullptr)) ||
      ((input_to_input_weights == nullptr) &&
       (recurrent_to_input_weights == nullptr));
  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);

  const TfLiteTensor* cell_to_input_weights =
      GetOptionalInputTensor(context, node, cell_to_input_weights_tensor);
  if (cell_to_input_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_input_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_forget_weights =
      GetOptionalInputTensor(context, node, cell_to_forget_weights_tensor);
  if (cell_to_forget_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_forget_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* cell_to_output_weights =
      GetOptionalInputTensor(context, node, cell_to_output_weights_tensor);
  if (cell_to_output_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, cell_to_output_weights->type,
                            input_to_forget_weights->type);
  }

  // Make sure the peephole weights are either all present or all absent.
  const bool use_cifg = (input_to_input_weights == nullptr);
  const bool peephole_weights_all_or_none =
      ((cell_to_input_weights != nullptr || use_cifg) &&
       (cell_to_forget_weights != nullptr) &&
       (cell_to_output_weights != nullptr)) ||
      ((cell_to_input_weights == nullptr) &&
       (cell_to_forget_weights == nullptr) &&
       (cell_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);

  // Make sure the input gate bias is present only when not a CIFG-LSTM.
  const TfLiteTensor* input_gate_bias =
      GetOptionalInputTensor(context, node, input_gate_bias_tensor);
  if (use_cifg) {
    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
  } else {
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32);
  }

  const TfLiteTensor* forget_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, forget_gate_bias_tensor, &forget_gate_bias));
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, cell_gate_bias_tensor,
                                          &cell_gate_bias));
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_EQ(context, cell_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* output_gate_bias;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, output_gate_bias_tensor, &output_gate_bias));
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
  TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32);

  const TfLiteTensor* projection_weights =
      GetOptionalInputTensor(context, node, projection_weights_tensor);
  if (projection_weights != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
    TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type,
                            input_to_forget_weights->type);
  }

  const TfLiteTensor* projection_bias =
      GetOptionalInputTensor(context, node, projection_bias_tensor);
  if (projection_bias != nullptr) {
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
    TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32);
  }

  // Making sure the projection tensors are consistent:
  // 1) If projection weight is not present, then projection bias should not be
  // present.
  // 2) If projection weight is present, then projection bias is optional.
  // TODO(ghodrat): make sure this is correct.
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);

  return kTfLiteOk;
}

TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
                                        TfLiteNode* node, int n_input,
                                        int n_output, int n_cell) {
  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor,
          kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor,
          kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor,
          kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor,
          kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor,
          kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor,
          kFwForgetGateBiasTensor, kFwCellGateBiasTensor,
          kFwOutputGateBiasTensor, kFwProjectionWeightsTensor,
          kFwProjectionBiasTensor));

  TF_LITE_ENSURE_OK(
      context,
      CheckLstmTensorDimensionsAndTypes(
          context, node, n_input, n_output, n_cell,
          kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor,
          kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor,
          kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor,
          kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor,
          kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor,
          kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor,
          kBwForgetGateBiasTensor, kBwCellGateBiasTensor,
          kBwOutputGateBiasTensor, kBwProjectionWeightsTensor,
          kBwProjectionBiasTensor));

  // Check if Forward and Backward tensors match along required dimensions.
  return kTfLiteOk;
}

// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the sizes of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  // Infer the batch size, sequence length, number of outputs and number of
  // cells from the input tensors.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int max_time = time_major ? input->dims->data[0] : input->dims->data[1];
  const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0];
  const int n_input = input->dims->data[2];

  const TfLiteTensor* fw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
                                 &fw_input_to_output_weights));
  const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
                    n_input);

  const TfLiteTensor* bw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
                                 &bw_input_to_output_weights));
  const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
                    n_input);
  TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->type,
                    fw_input_to_output_weights->type);

  const TfLiteTensor* fw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
                            &fw_recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0],
                    n_fw_cell);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];

  const TfLiteTensor* bw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
                            &bw_recurrent_to_output_weights));
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
                    n_bw_cell);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->type,
                    fw_input_to_output_weights->type);
  const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
                                          n_fw_cell));

  // Get (optional) auxiliary inputs and weights.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  const bool aux_inputs_weights_all_or_none =
      ((fw_aux_input_to_cell_weights != nullptr) &&
       (fw_aux_input_to_forget_weights != nullptr) &&
       (fw_aux_input_to_output_weights != nullptr) &&
       (bw_aux_input_to_cell_weights != nullptr) &&
       (bw_aux_input_to_forget_weights != nullptr) &&
       (bw_aux_input_to_output_weights != nullptr)) ||
      ((fw_aux_input_to_cell_weights == nullptr) &&
       (fw_aux_input_to_forget_weights == nullptr) &&
       (fw_aux_input_to_output_weights == nullptr) &&
       (bw_aux_input_to_cell_weights == nullptr) &&
       (bw_aux_input_to_forget_weights == nullptr) &&
       (bw_aux_input_to_output_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_all_or_none);

  const bool has_aux_input = (fw_aux_input_to_forget_weights != nullptr);
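  // The forget-gate aux weights serve as the presence indicator here: the
  // all-or-none check above guarantees they exist whenever any aux weights do.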

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
  }

  // Get the pointer to output, activation_state and cell_state buffer tensors.
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TF_LITE_ENSURE(context, fw_activation_state != nullptr);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TF_LITE_ENSURE(context, fw_cell_state != nullptr);

  // Check the shape of input state tensors.
  // These tensors may be 1D or 2D. It's fine as long as the total size is
  // correct.
  TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
                    n_batch * n_fw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);

  // Resize the output tensors.
  TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
  fw_output_size->data[0] = time_major ? max_time : n_batch;
  fw_output_size->data[1] = time_major ? n_batch : max_time;
  fw_output_size->data[2] =
      params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, fw_output, fw_output_size));

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = IsHybridOp(input, fw_input_to_output_weights);
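  // A hybrid op has float activations but quantized (int8/uint8) weights. In
  // that case additional temporaries are allocated below for quantized copies
  // of the inputs and states, their scaling factors and zero points, and the
  // cached weight row sums.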

  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(
        has_aux_input ? kNumTemporaryTensors : kNumTemporaryTensors - 1);
  } else {
    node->temporaries = TfLiteIntArrayCreate(2);  // the two scratch buffers.
  }
  // Create a scratch buffer tensor.
  node->temporaries->data[kFwScratchBuffer] =
      op_data->scratch_tensor_index + kFwScratchBuffer;
  TfLiteTensor* fw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
                                              &fw_scratch_buffer));
  fw_scratch_buffer->type = input->type;
  fw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
  if (has_aux_input && !fw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
                      fw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  fw_scratch_buffer_size->data[0] = n_batch;
  if (fw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates and scratch accumulation
    // buffer and an extra 16 bytes to avoid internal ruy copies.
    fw_scratch_buffer_size->data[1] = n_fw_cell * 4 + 16;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates and scratch
    // accumulation buffer and an extra 16 bytes to avoid internal ruy copies.
    fw_scratch_buffer_size->data[1] = n_fw_cell * 5 + 16;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
                                                   fw_scratch_buffer_size));
  // Same for the backward cell.

  // Check that the input tensor dimensions match each other.
  TF_LITE_ENSURE_OK(
      context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
                                          n_bw_cell));

  // Get the pointer to activation_state and cell_state buffer tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TF_LITE_ENSURE(context, bw_activation_state != nullptr);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);
  TF_LITE_ENSURE(context, bw_cell_state != nullptr);

  // Resize the output tensors.
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output;
    TF_LITE_ENSURE_OK(
        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
    TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
    bw_output_size->data[0] = time_major ? max_time : n_batch;
    bw_output_size->data[1] = time_major ? n_batch : max_time;
    bw_output_size->data[2] = n_bw_output;
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, bw_output, bw_output_size));
  }

  // Check the shape of input state tensors.
  // These tensors may be 1D or 2D. It's fine as long as the total size is
  // correct.
  TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
                    n_batch * n_bw_output);
  TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);

  // Create a scratch buffer tensor.
  node->temporaries->data[kBwScratchBuffer] =
      op_data->scratch_tensor_index + kBwScratchBuffer;
  TfLiteTensor* bw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
                                              &bw_scratch_buffer));
  bw_scratch_buffer->type = input->type;
  bw_scratch_buffer->allocation_type = kTfLiteArenaRw;

  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
  if (has_aux_input && !bw_use_cifg) {
    TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
                      bw_input_to_input_weights->dims->data[0]);
  }
  TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
  bw_scratch_buffer_size->data[0] = n_batch;
  if (bw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates and the scratch
    // accumulation buffer (no extra padding bytes for the backward cell).
    bw_scratch_buffer_size->data[1] = n_bw_cell * 4;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates and the scratch
    // accumulation buffer (no extra padding bytes for the backward cell).
    bw_scratch_buffer_size->data[1] = n_bw_cell * 5;
  }
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
                                                   bw_scratch_buffer_size));
  if (is_hybrid_op) {
    // Compute the row sums for cached zero_point offset calculation.
    op_data->compute_fw_row_sums = true;
    op_data->compute_bw_row_sums = true;
    // Allocate temporary tensors to store quantized values of input, aux_input
    // (if present), activation_state and cell_state tensors.
    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = fw_input_to_output_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    node->temporaries->data[kFwActivationStateQuantized] =
        op_data->scratch_tensor_index + kFwActivationStateQuantized;
    TfLiteTensor* fw_activation_state_quantized;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
                                  &fw_activation_state_quantized));
    fw_activation_state_quantized->type = fw_input_to_output_weights->type;
    fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
                             fw_activation_state->dims)) {
      TfLiteIntArray* fw_activation_state_quantized_size =
          TfLiteIntArrayCopy(fw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_activation_state_quantized,
                                         fw_activation_state_quantized_size));
    }
    node->temporaries->data[kBwActivationStateQuantized] =
        op_data->scratch_tensor_index + kBwActivationStateQuantized;
    TfLiteTensor* bw_activation_state_quantized;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
                                  &bw_activation_state_quantized));
    bw_activation_state_quantized->type = fw_input_to_output_weights->type;
    bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
                             bw_activation_state->dims)) {
      TfLiteIntArray* bw_activation_state_quantized_size =
          TfLiteIntArrayCopy(bw_activation_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_activation_state_quantized,
                                         bw_activation_state_quantized_size));
    }
    node->temporaries->data[kFwCellStateQuantized] =
        op_data->scratch_tensor_index + kFwCellStateQuantized;
    TfLiteTensor* fw_cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kFwCellStateQuantized,
                                       &fw_cell_state_quantized));
    fw_cell_state_quantized->type = fw_input_to_output_weights->type;
    fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
                             fw_cell_state->dims)) {
      TfLiteIntArray* fw_cell_state_quantized_size =
          TfLiteIntArrayCopy(fw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, fw_cell_state_quantized,
                                              fw_cell_state_quantized_size));
    }
    node->temporaries->data[kBwCellStateQuantized] =
        op_data->scratch_tensor_index + kBwCellStateQuantized;
    TfLiteTensor* bw_cell_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kBwCellStateQuantized,
                                       &bw_cell_state_quantized));
    bw_cell_state_quantized->type = fw_input_to_output_weights->type;
    bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
                             bw_cell_state->dims)) {
      TfLiteIntArray* bw_cell_state_quantized_size =
          TfLiteIntArrayCopy(bw_cell_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, bw_cell_state_quantized,
                                              bw_cell_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors and product scaling
    // factors. The latter is a convenience storage that makes it possible to
    // quantize a vector once (producing the scaling factors) and multiply it
    // with different matrices (which requires multiplying the vector's scaling
    // factors by the scaling factor of each matrix).
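    // Concretely, the per-batch input scaling factors computed here are later
    // combined with each weight matrix's scale to form the product scaling
    // factors used when accumulating the hybrid matrix-vector products.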
    node->temporaries->data[kInputScalingFactors] =
        op_data->scratch_tensor_index + kInputScalingFactors;
    TfLiteTensor* input_sf;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kInputScalingFactors, &input_sf));
    input_sf->type = kTfLiteFloat32;
    input_sf->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {n_batch};
    if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1);
      input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_sf, input_sf_size));
    }
    node->temporaries->data[kAuxInputScalingFactors] =
        op_data->scratch_tensor_index + kAuxInputScalingFactors;
    TfLiteTensor* aux_input_sf;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kAuxInputScalingFactors,
                                       &aux_input_sf));
    aux_input_sf->type = kTfLiteFloat32;
    aux_input_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* aux_input_sf_size = TfLiteIntArrayCreate(1);
      aux_input_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_sf,
                                                       aux_input_sf_size));
    }
    node->temporaries->data[kOutputStateScalingFactors] =
        op_data->scratch_tensor_index + kOutputStateScalingFactors;
    TfLiteTensor* output_state_sf;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kOutputStateScalingFactors,
                                  &output_state_sf));
    output_state_sf->type = kTfLiteFloat32;
    output_state_sf->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1);
      output_state_sf_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf,
                                                       output_state_sf_size));
    }
    node->temporaries->data[kProductScalingFactors] =
        op_data->scratch_tensor_index + kProductScalingFactors;
    TfLiteTensor* prod_scaling_factors;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kProductScalingFactors,
                                       &prod_scaling_factors));
    prod_scaling_factors->type = kTfLiteFloat32;
    prod_scaling_factors->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
                                   scaling_dims)) {
      TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
      prod_scaling_factors_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, prod_scaling_factors,
                                              prod_scaling_factors_size));
    }

    // Allocate a temporary tensor to store the recovered cell weights. Since
    // these are diagonal matrices, only n_cell values need to be stored.
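    // The peephole (cell-to-gate) weights are stored quantized in the hybrid
    // kernel and are dequantized ("recovered") into this float buffer before
    // being applied to the float cell state.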
    node->temporaries->data[kRecoveredCellWeights] =
        op_data->scratch_tensor_index + kRecoveredCellWeights;
    TfLiteTensor* recovered_cell_weights;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kRecoveredCellWeights,
                                       &recovered_cell_weights));
    recovered_cell_weights->type = kTfLiteFloat32;
    recovered_cell_weights->allocation_type = kTfLiteArenaRw;
    int recovered_cell_dims[1] = {n_fw_cell};
    if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1,
                                   recovered_cell_dims)) {
      TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
      recovered_cell_weights_size->data[0] = n_fw_cell;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, recovered_cell_weights,
                                              recovered_cell_weights_size));
    }

    // Allocate a temporary tensor to store the accumulated int32 values.
    node->temporaries->data[kAccumScratchBuffer] =
        op_data->scratch_tensor_index + kAccumScratchBuffer;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int n_cell = std::max(n_fw_cell, n_bw_cell);
    if (has_aux_input) {
      n_cell = std::max(n_cell, fw_aux_input_to_output_weights->dims->data[0]);
      n_cell = std::max(n_cell, bw_aux_input_to_output_weights->dims->data[0]);
    }
    int accum_scratch_dims[2] = {n_cell, n_batch};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = n_cell;
      accum_size->data[1] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }

    // Allocate temporary tensors for storing zero-points.
    node->temporaries->data[kInputZeroPoints] =
        op_data->scratch_tensor_index + kInputZeroPoints;
    TfLiteTensor* input_zp;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp));
    input_zp->type = kTfLiteFloat32;
    input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1);
      input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, input_zp, input_zp_size));
    }
    node->temporaries->data[kAuxInputZeroPoints] =
        op_data->scratch_tensor_index + kAuxInputZeroPoints;
    TfLiteTensor* aux_input_zp;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, kAuxInputZeroPoints, &aux_input_zp));
    aux_input_zp->type = kTfLiteFloat32;
    aux_input_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* aux_input_zp_size = TfLiteIntArrayCreate(1);
      aux_input_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_zp,
                                                       aux_input_zp_size));
    }
    node->temporaries->data[kOutputStateZeroPoints] =
        op_data->scratch_tensor_index + kOutputStateZeroPoints;
    TfLiteTensor* output_state_zp;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kOutputStateZeroPoints,
                                       &output_state_zp));
    output_state_zp->type = kTfLiteFloat32;
    output_state_zp->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) {
      TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1);
      output_state_zp_size->data[0] = n_batch;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp,
                                                       output_state_zp_size));
    }

    // Allocate temporary tensors for caching row sums for hybrid zero-point
    // calculations.
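    // Each cached row holds the per-row sums of one quantized weight matrix;
    // with asymmetric input quantization these sums let the zero-point
    // correction term be reused instead of recomputed on every invocation.
    // The row counts below follow from the number of weight matrices: 3 gates
    // (CIFG) or 4 gates, each with an input and a recurrent matrix, plus
    // optional aux-input matrices and a projection matrix whose n_output row
    // sums are packed into rows of width n_cell.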
    int fw_row_sums_rows = fw_use_cifg ? 6 : 8;
    if (has_aux_input) {
      fw_row_sums_rows += fw_use_cifg ? 3 : 4;
    }
    const TfLiteTensor* fw_projection_weights =
        GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
    if (fw_projection_weights != nullptr) {
      fw_row_sums_rows += ceil(static_cast<float>(n_fw_output) / n_fw_cell);
    }
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
    TfLiteTensor* fw_row_sums;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell};
    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
      TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2);
      fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0];
      fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
                                                       fw_hybrid_scratch_size));
    }

    int bw_row_sums_rows = bw_use_cifg ? 6 : 8;
    if (has_aux_input) {
      bw_row_sums_rows += bw_use_cifg ? 3 : 4;
    }
    const TfLiteTensor* bw_projection_weights =
        GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
    if (bw_projection_weights != nullptr) {
      bw_row_sums_rows += ceil(static_cast<float>(n_bw_output) / n_bw_cell);
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
    TfLiteTensor* bw_row_sums;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell};
    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
                                                       bw_row_sums_size));
    }

    // Only allocate a temporary tensor for quantized auxiliary input if we are
    // actually going to use it.
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kAuxInputQuantized,
                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_to_output_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }
  return kTfLiteOk;
}

// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
      node->builtin_data);
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  // Input tensor.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));

  // Tensors for the forward cell.
  const TfLiteTensor* fw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
  const TfLiteTensor* fw_input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToForgetWeightsTensor,
                                 &fw_input_to_forget_weights));
  const TfLiteTensor* fw_input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToCellWeightsTensor,
                                 &fw_input_to_cell_weights));
  const TfLiteTensor* fw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwInputToOutputWeightsTensor,
                                 &fw_input_to_output_weights));

  const TfLiteTensor* fw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor);
  const TfLiteTensor* fw_recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToForgetWeightsTensor,
                            &fw_recurrent_to_forget_weights));
  const TfLiteTensor* fw_recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentToCellWeightsTensor,
                                 &fw_recurrent_to_cell_weights));
  const TfLiteTensor* fw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor,
                            &fw_recurrent_to_output_weights));

  const TfLiteTensor* fw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor);
  const TfLiteTensor* fw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwCellToForgetWeightsTensor);
  const TfLiteTensor* fw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kFwCellToOutputWeightsTensor);

  const TfLiteTensor* fw_input_gate_bias =
      GetOptionalInputTensor(context, node, kFwInputGateBiasTensor);
  const TfLiteTensor* fw_forget_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwForgetGateBiasTensor,
                                 &fw_forget_gate_bias));
  const TfLiteTensor* fw_cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwCellGateBiasTensor,
                                          &fw_cell_gate_bias));
  const TfLiteTensor* fw_output_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwOutputGateBiasTensor,
                                 &fw_output_gate_bias));

  const TfLiteTensor* fw_projection_weights =
      GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor);
  const TfLiteTensor* fw_projection_bias =
      GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);

  TfLiteTensor* fw_activation_state =
      GetVariableInput(context, node, kFwInputActivationStateTensor);
  TFLITE_DCHECK(fw_activation_state != nullptr);
  TfLiteTensor* fw_cell_state =
      GetVariableInput(context, node, kFwInputCellStateTensor);
  TFLITE_DCHECK(fw_cell_state != nullptr);
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));

  // Tensors for the backward cell.
  const TfLiteTensor* bw_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
  const TfLiteTensor* bw_input_to_forget_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToForgetWeightsTensor,
                                 &bw_input_to_forget_weights));
  const TfLiteTensor* bw_input_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToCellWeightsTensor,
                                 &bw_input_to_cell_weights));
  const TfLiteTensor* bw_input_to_output_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwInputToOutputWeightsTensor,
                                 &bw_input_to_output_weights));

  const TfLiteTensor* bw_recurrent_to_input_weights =
      GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor);
  const TfLiteTensor* bw_recurrent_to_forget_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToForgetWeightsTensor,
                            &bw_recurrent_to_forget_weights));
  const TfLiteTensor* bw_recurrent_to_cell_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentToCellWeightsTensor,
                                 &bw_recurrent_to_cell_weights));
  const TfLiteTensor* bw_recurrent_to_output_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor,
                            &bw_recurrent_to_output_weights));

  const TfLiteTensor* bw_cell_to_input_weights =
      GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor);
  const TfLiteTensor* bw_cell_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwCellToForgetWeightsTensor);
  const TfLiteTensor* bw_cell_to_output_weights =
      GetOptionalInputTensor(context, node, kBwCellToOutputWeightsTensor);

  const TfLiteTensor* bw_input_gate_bias =
      GetOptionalInputTensor(context, node, kBwInputGateBiasTensor);
  const TfLiteTensor* bw_forget_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwForgetGateBiasTensor,
                                 &bw_forget_gate_bias));
  const TfLiteTensor* bw_cell_gate_bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwCellGateBiasTensor,
                                          &bw_cell_gate_bias));
  const TfLiteTensor* bw_output_gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwOutputGateBiasTensor,
                                 &bw_output_gate_bias));

  const TfLiteTensor* bw_projection_weights =
      GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor);
  const TfLiteTensor* bw_projection_bias =
      GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);

  // State tensors.
  TfLiteTensor* bw_activation_state =
      GetVariableInput(context, node, kBwInputActivationStateTensor);
  TFLITE_DCHECK(bw_activation_state != nullptr);
  TfLiteTensor* bw_cell_state =
      GetVariableInput(context, node, kBwInputCellStateTensor);
  TFLITE_DCHECK(bw_cell_state != nullptr);
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  // Temporary tensors.
  TfLiteTensor* fw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer,
                                              &fw_scratch_buffer));
  TfLiteTensor* bw_scratch_buffer;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer,
                                              &bw_scratch_buffer));

  // (Optional) auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
  const TfLiteTensor* fw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_forget_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_cell_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
  const TfLiteTensor* bw_aux_input_to_output_weights =
      GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_to_forget_weights != nullptr);

  // Populate a TfLiteLSTMParams struct for the evaluation functions.
  TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
                                  params->proj_clip, kTfLiteLSTMFullKernel,
                                  params->asymmetric_quantize_inputs};

  const int bw_output_offset =
      params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
  const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
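  // With merged outputs the backward pass writes into the forward output
  // tensor at an offset of n_fw_output along the last dimension; otherwise it
  // writes into its own output tensor with no offset.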

  const bool time_major = params->time_major;

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross_links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note that in this case it does not matter whether this op follows other
  //   bidi lstms or not.
  //
  // If stacking without cross_links, but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

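  // non_stacking_mode below corresponds to the third case above: the backward
  // cell reads the aux input as its primary input and no aux input is passed
  // on to the LSTM evaluation itself.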
  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;

  switch (fw_input_to_output_weights->type) {
    case kTfLiteFloat32: {
      TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
          input, fw_input_to_input_weights, fw_input_to_forget_weights,
          fw_input_to_cell_weights, fw_input_to_output_weights,
          fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
          fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
          fw_output_gate_bias, fw_projection_weights, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, time_major, /*output_offset=*/0,
          fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, fw_pass_status);

      TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
          bw_input, bw_input_to_input_weights, bw_input_to_forget_weights,
          bw_input_to_cell_weights, bw_input_to_output_weights,
          bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
          bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
          bw_output_gate_bias, bw_projection_weights, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, time_major, bw_output_offset,
          bw_scratch_buffer, bw_activation_state, bw_cell_state,
          actual_bw_output, CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
      TfLiteTensor* fw_activation_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwActivationStateQuantized,
                                    &fw_activation_state_quantized));
      TfLiteTensor* bw_activation_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwActivationStateQuantized,
                                    &bw_activation_state_quantized));
      TfLiteTensor* fw_cell_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kFwCellStateQuantized,
                                         &fw_cell_state_quantized));
      TfLiteTensor* bw_cell_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kBwCellStateQuantized,
                                         &bw_cell_state_quantized));
      TfLiteTensor* prod_scaling_factors;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kProductScalingFactors,
                                         &prod_scaling_factors));
      TfLiteTensor* recovered_cell_weights;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kRecoveredCellWeights,
                                         &recovered_cell_weights));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch));
      TfLiteTensor* fw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
      TfLiteTensor* bw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      const int fw_row_sums_size = fw_row_sums->dims->data[0];
      const int bw_row_sums_size = bw_row_sums->dims->data[0];
      TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
          input, fw_input_to_input_weights,
          /*input_to_input_weights_ledger*/ nullptr, fw_input_to_forget_weights,
          /*input_to_forget_weights_ledger*/ nullptr, fw_input_to_cell_weights,
          /*input_to_cell_weights_ledger*/ nullptr, fw_input_to_output_weights,
          /*input_to_output_weights_ledger*/ nullptr,
          fw_recurrent_to_input_weights,
          /*recurrent_to_input_weights_ledger*/ nullptr,
          fw_recurrent_to_forget_weights,
          /*recurrent_to_forget_weights_ledger*/ nullptr,
          fw_recurrent_to_cell_weights,
          /*recurrent_to_cell_weights_ledger*/ nullptr,
          fw_recurrent_to_output_weights,
          /*recurrent_to_output_weights_ledger*/ nullptr,
          fw_cell_to_input_weights, fw_cell_to_forget_weights,
          fw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights,
          fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights,
          fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias,
          fw_output_gate_bias, fw_projection_weights,
          /*projection_weights_ledger*/ nullptr, fw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/true, time_major, /*output_offset=*/0,
          fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
          GetTemporary(context, node, kAuxInputScalingFactors),
          GetTemporary(context, node, kOutputStateScalingFactors),
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          aux_input_quantized, fw_activation_state_quantized,
          fw_cell_state_quantized, fw_activation_state, fw_cell_state,
          accum_scratch, fw_output,
          GetTemporary(context, node, kInputZeroPoints),
          GetTemporary(context, node, kAuxInputZeroPoints),
          GetTemporary(context, node, kOutputStateZeroPoints), fw_row_sums,
          fw_row_sums_size, &op_data->compute_fw_row_sums,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, fw_pass_status);

      TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
          bw_input, bw_input_to_input_weights,
          /*input_to_input_weights_ledger*/ nullptr, bw_input_to_forget_weights,
          /*input_to_forget_weights_ledger*/ nullptr, bw_input_to_cell_weights,
          /*input_to_cell_weights_ledger*/ nullptr, bw_input_to_output_weights,
          /*input_to_output_weights_ledger*/ nullptr,
          bw_recurrent_to_input_weights,
          /*recurrent_to_input_weights_ledger*/ nullptr,
          bw_recurrent_to_forget_weights,
          /*recurrent_to_forget_weights_ledger*/ nullptr,
          bw_recurrent_to_cell_weights,
          /*recurrent_to_cell_weights_ledger*/ nullptr,
          bw_recurrent_to_output_weights,
          /*recurrent_to_output_weights_ledger*/ nullptr,
          bw_cell_to_input_weights, bw_cell_to_forget_weights,
          bw_cell_to_output_weights,
          /*input_layer_norm_coefficients=*/nullptr,
          /*forget_layer_norm_coefficients=*/nullptr,
          /*cell_layer_norm_coefficients=*/nullptr,
          /*output_layer_norm_coefficients=*/nullptr, real_aux_input,
          bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights,
          bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights,
          bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias,
          bw_output_gate_bias, bw_projection_weights,
          /*projection_weights_ledger*/ nullptr, bw_projection_bias,
          &lstm_params,
          /*forward_sequence=*/false, time_major, bw_output_offset,
          bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors),
          GetTemporary(context, node, kAuxInputScalingFactors),
          GetTemporary(context, node, kOutputStateScalingFactors),
          prod_scaling_factors, recovered_cell_weights, input_quantized,
          aux_input_quantized, bw_activation_state_quantized,
          bw_cell_state_quantized, bw_activation_state, bw_cell_state,
          accum_scratch, actual_bw_output,
          GetTemporary(context, node, kInputZeroPoints),
          GetTemporary(context, node, kAuxInputZeroPoints),
          GetTemporary(context, node, kOutputStateZeroPoints), bw_row_sums,
          bw_row_sums_size, &op_data->compute_bw_row_sums,
          CpuBackendContext::GetFromContext(context));
      TF_LITE_ENSURE_OK(context, bw_pass_status);
      return kTfLiteOk;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
                         TfLiteTypeGetName(fw_input_to_output_weights->type));
      return kTfLiteError;
  }
}

}  // namespace bidirectional_sequence_lstm

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
  static TfLiteRegistration r = {
      bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free,
      bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
