1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <math.h> |
17 | |
18 | #include <algorithm> |
19 | #include <cstddef> |
20 | |
21 | #include "tensorflow/lite/c/builtin_op_data.h" |
22 | #include "tensorflow/lite/c/common.h" |
23 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
24 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
25 | #include "tensorflow/lite/kernels/internal/kernel_utils.h" |
26 | #include "tensorflow/lite/kernels/internal/tensor_utils.h" |
27 | #include "tensorflow/lite/kernels/kernel_util.h" |
28 | #include "tensorflow/lite/kernels/lstm_eval.h" |
29 | #include "tensorflow/lite/kernels/op_macros.h" |
30 | |
31 | namespace tflite { |
32 | namespace ops { |
33 | namespace builtin { |
34 | namespace bidirectional_sequence_lstm { |
35 | |
36 | // LINT.IfChange |
37 | |
// Input tensor of size {max_time, n_batch, n_input} when time_major is set,
// or {n_batch, max_time, n_input} otherwise.
39 | constexpr int kInputTensor = 0; |
40 | |
41 | // Forward LSTM cell tensors. |
42 | // Input weight tensors of size: {n_cell, n_input} |
43 | constexpr int kFwInputToInputWeightsTensor = 1; // Optional |
44 | constexpr int kFwInputToForgetWeightsTensor = 2; |
45 | constexpr int kFwInputToCellWeightsTensor = 3; |
46 | constexpr int kFwInputToOutputWeightsTensor = 4; |
47 | |
48 | // Recurrent weight tensors of size {n_cell, n_output} |
49 | constexpr int kFwRecurrentToInputWeightsTensor = 5; // Optional |
50 | constexpr int kFwRecurrentToForgetWeightsTensor = 6; |
51 | constexpr int kFwRecurrentToCellWeightsTensor = 7; |
52 | constexpr int kFwRecurrentToOutputWeightsTensor = 8; |
53 | |
54 | // Peephole weights tensors of size {n_cell}, representing a diagonal matrix. |
55 | constexpr int kFwCellToInputWeightsTensor = 9; // Optional |
56 | constexpr int kFwCellToForgetWeightsTensor = 10; // Optional |
57 | constexpr int kFwCellToOutputWeightsTensor = 11; // Optional |
58 | |
59 | // Gates bias tensors of size {n_cell} |
60 | constexpr int kFwInputGateBiasTensor = 12; // Optional |
61 | constexpr int kFwForgetGateBiasTensor = 13; |
62 | constexpr int kFwCellGateBiasTensor = 14; |
63 | constexpr int kFwOutputGateBiasTensor = 15; |
64 | |
65 | // Projection weight tensor of size {n_output, n_cell} |
66 | constexpr int kFwProjectionWeightsTensor = 16; // Optional |
67 | // Projection bias tensor of size {n_output} |
68 | constexpr int kFwProjectionBiasTensor = 17; // Optional |
69 | |
70 | // Backward LSTM cell tensors. |
71 | // Input weight tensors of size: {n_cell, n_input} |
72 | constexpr int kBwInputToInputWeightsTensor = 18; // Optional |
73 | constexpr int kBwInputToForgetWeightsTensor = 19; |
74 | constexpr int kBwInputToCellWeightsTensor = 20; |
75 | constexpr int kBwInputToOutputWeightsTensor = 21; |
76 | |
77 | // Recurrent weight tensors of size {n_cell, n_output} |
78 | constexpr int kBwRecurrentToInputWeightsTensor = 22; // Optional |
79 | constexpr int kBwRecurrentToForgetWeightsTensor = 23; |
80 | constexpr int kBwRecurrentToCellWeightsTensor = 24; |
81 | constexpr int kBwRecurrentToOutputWeightsTensor = 25; |
82 | |
83 | // Peephole weights tensors of size {n_cell}, representing a diagonal matrix. |
84 | constexpr int kBwCellToInputWeightsTensor = 26; // Optional |
85 | constexpr int kBwCellToForgetWeightsTensor = 27; // Optional |
86 | constexpr int kBwCellToOutputWeightsTensor = 28; // Optional |
87 | |
88 | // Gates bias tensors of size {n_cell} |
89 | constexpr int kBwInputGateBiasTensor = 29; // Optional |
90 | constexpr int kBwForgetGateBiasTensor = 30; |
91 | constexpr int kBwCellGateBiasTensor = 31; |
92 | constexpr int kBwOutputGateBiasTensor = 32; |
93 | |
94 | // Projection weight tensor of size {n_output, n_cell} |
95 | constexpr int kBwProjectionWeightsTensor = 33; // Optional |
96 | // Projection bias tensor of size {n_output} |
97 | constexpr int kBwProjectionBiasTensor = 34; // Optional |
98 | |
99 | // Stateful input tensors that are variables and will be modified by the Op. |
100 | // Activation state tensors of size {n_batch, n_output} |
101 | constexpr int kFwInputActivationStateTensor = 35; |
102 | // Cell state tensors of size {n_batch, n_cell} |
103 | constexpr int kFwInputCellStateTensor = 36; |
104 | // Activation state tensors of size {n_batch, n_output} |
105 | constexpr int kBwInputActivationStateTensor = 37; |
106 | // Cell state tensors of size {n_batch, n_cell} |
107 | constexpr int kBwInputCellStateTensor = 38; |
108 | |
// Used as auxiliary input and weights when stacking for the
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as
// input to the backward cell when stacking for the
// tf.nn.static_bidirectional_rnn case (without cross links).
113 | constexpr int kAuxInputTensor = 39; // Optional |
114 | // Forward weights. |
115 | constexpr int kFwAuxInputToInputWeightsTensor = 40; // Optional |
116 | constexpr int kFwAuxInputToForgetWeightsTensor = 41; // Optional |
117 | constexpr int kFwAuxInputToCellWeightsTensor = 42; // Optional |
118 | constexpr int kFwAuxInputToOutputWeightsTensor = 43; // Optional |
119 | // Backward weights. |
120 | constexpr int kBwAuxInputToInputWeightsTensor = 44; // Optional |
121 | constexpr int kBwAuxInputToForgetWeightsTensor = 45; // Optional |
122 | constexpr int kBwAuxInputToCellWeightsTensor = 46; // Optional |
123 | constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional |
124 | |
125 | // Output tensors. |
126 | constexpr int kFwOutputTensor = 0; |
127 | constexpr int kBwOutputTensor = 1; // Ignored if merge_outputs is set. |
128 | |
129 | // LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc) |
130 | |
131 | // Temporary tensors. |
132 | enum TemporaryTensor { |
133 | // Scratch buffers for input, forget, etc. gates |
134 | kFwScratchBuffer = 0, |
135 | kBwScratchBuffer = 1, |
136 | // Quantized tensors needed for the hybrid kernel. |
137 | kInputQuantized = 2, |
138 | kFwActivationStateQuantized = 3, |
139 | kBwActivationStateQuantized = 4, |
140 | kFwCellStateQuantized = 5, |
141 | kBwCellStateQuantized = 6, |
142 | kInputScalingFactors = 7, |
143 | kAuxInputScalingFactors = 8, |
144 | kOutputStateScalingFactors = 9, |
145 | kProductScalingFactors = 10, |
146 | kRecoveredCellWeights = 11, |
147 | kAccumScratchBuffer = 12, |
148 | kInputZeroPoints = 13, |
149 | kAuxInputZeroPoints = 14, |
150 | kOutputStateZeroPoints = 15, |
151 | kFwRowSums = 16, |
152 | kBwRowSums = 17, |
153 | kAuxInputQuantized = 18, // Optional, quantized tensor for auxiliary input. |
154 | kNumTemporaryTensors = 19, |
155 | }; |
156 | |
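// Per-node data. `scratch_tensor_index` is the index of the first of the
// kNumTemporaryTensors consecutive temporaries that Init() reserves via
// context->AddTensors(); temporary k is addressed as
// scratch_tensor_index + k throughout Prepare().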
157 | struct OpData { |
158 | int scratch_tensor_index; |
159 | bool compute_fw_row_sums = false; |
160 | bool compute_bw_row_sums = false; |
161 | }; |
162 | |
163 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
164 | auto* op_data = new OpData(); |
165 | context->AddTensors(context, kNumTemporaryTensors, |
166 | &op_data->scratch_tensor_index); |
167 | return op_data; |
168 | } |
169 | |
170 | void Free(TfLiteContext* context, void* buffer) { |
171 | delete reinterpret_cast<OpData*>(buffer); |
172 | } |
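
// Illustrative sketch only: Init/Free/Prepare/Eval defined in this file are
// typically bundled into a kernel registration along these lines (the actual
// registration lives with the other builtin ops, not in this section):
//
//   TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() {
//     static TfLiteRegistration r = {Init, Free, Prepare, Eval};
//     return &r;
//   }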
173 | |
// Check that the LSTM tensor dimensions and types are consistent with each
// other.
175 | TfLiteStatus CheckLstmTensorDimensionsAndTypes( |
176 | TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, |
177 | int n_cell, int input_to_input_weights_tensor, |
178 | int input_to_forget_weights_tensor, int input_to_cell_weights_tensor, |
179 | int input_to_output_weights_tensor, int recurrent_to_input_weights_tensor, |
180 | int recurrent_to_forget_weights_tensor, |
181 | int recurrent_to_cell_weights_tensor, |
182 | int recurrent_to_output_weights_tensor, int cell_to_input_weights_tensor, |
183 | int cell_to_forget_weights_tensor, int cell_to_output_weights_tensor, |
184 | int input_gate_bias_tensor, int forget_gate_bias_tensor, |
185 | int cell_gate_bias_tensor, int output_gate_bias_tensor, |
186 | int projection_weights_tensor, int projection_bias_tensor) { |
187 | const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>( |
188 | node->builtin_data); |
189 | |
190 | // Making sure clipping parameters have valid values. |
191 | // == 0 means no clipping |
192 | // > 0 means clipping |
193 | TF_LITE_ENSURE(context, params->cell_clip >= 0); |
194 | TF_LITE_ENSURE(context, params->proj_clip >= 0); |
195 | |
196 | const TfLiteTensor* input_to_forget_weights; |
197 | TF_LITE_ENSURE_OK(context, |
198 | GetInputSafe(context, node, input_to_forget_weights_tensor, |
199 | &input_to_forget_weights)); |
200 | TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); |
201 | TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); |
202 | TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); |
203 | TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) || |
204 | (input_to_forget_weights->type == kTfLiteInt8) || |
205 | (input_to_forget_weights->type == kTfLiteUInt8)); |
206 | |
207 | const TfLiteTensor* input_to_input_weights = |
208 | GetOptionalInputTensor(context, node, input_to_input_weights_tensor); |
209 | if (input_to_input_weights != nullptr) { |
210 | TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); |
211 | TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); |
212 | TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); |
213 | TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type, |
214 | input_to_forget_weights->type); |
215 | } |
216 | |
217 | const TfLiteTensor* input_to_cell_weights; |
218 | TF_LITE_ENSURE_OK(context, |
219 | GetInputSafe(context, node, input_to_cell_weights_tensor, |
220 | &input_to_cell_weights)); |
221 | TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); |
222 | TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); |
223 | TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); |
224 | TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type, |
225 | input_to_forget_weights->type); |
226 | |
227 | const TfLiteTensor* input_to_output_weights; |
228 | TF_LITE_ENSURE_OK(context, |
229 | GetInputSafe(context, node, input_to_output_weights_tensor, |
230 | &input_to_output_weights)); |
231 | TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); |
232 | TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[0], n_cell); |
233 | TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); |
234 | TF_LITE_ENSURE_TYPES_EQ(context, input_to_output_weights->type, |
235 | input_to_forget_weights->type); |
236 | |
237 | const TfLiteTensor* recurrent_to_input_weights = |
238 | GetOptionalInputTensor(context, node, recurrent_to_input_weights_tensor); |
239 | if (recurrent_to_input_weights != nullptr) { |
240 | TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); |
241 | TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], |
242 | n_cell); |
243 | TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], |
244 | n_output); |
245 | TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type, |
246 | input_to_forget_weights->type); |
247 | } |
248 | |
249 | const TfLiteTensor* recurrent_to_forget_weights; |
250 | TF_LITE_ENSURE_OK( |
251 | context, GetInputSafe(context, node, recurrent_to_forget_weights_tensor, |
252 | &recurrent_to_forget_weights)); |
253 | TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); |
254 | TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], |
255 | n_cell); |
256 | TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], |
257 | n_output); |
258 | TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type, |
259 | input_to_forget_weights->type); |
260 | |
261 | const TfLiteTensor* recurrent_to_cell_weights; |
262 | TF_LITE_ENSURE_OK( |
263 | context, GetInputSafe(context, node, recurrent_to_cell_weights_tensor, |
264 | &recurrent_to_cell_weights)); |
265 | TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); |
266 | TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); |
267 | TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], |
268 | n_output); |
269 | TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type, |
270 | input_to_forget_weights->type); |
271 | |
  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
274 | const bool cifg_weights_all_or_none = |
275 | ((input_to_input_weights != nullptr) && |
276 | (recurrent_to_input_weights != nullptr)) || |
277 | ((input_to_input_weights == nullptr) && |
278 | (recurrent_to_input_weights == nullptr)); |
279 | TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); |
280 | |
281 | const TfLiteTensor* cell_to_input_weights = |
282 | GetOptionalInputTensor(context, node, cell_to_input_weights_tensor); |
283 | if (cell_to_input_weights != nullptr) { |
284 | TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); |
285 | TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); |
286 | TF_LITE_ENSURE_TYPES_EQ(context, cell_to_input_weights->type, |
287 | input_to_forget_weights->type); |
288 | } |
289 | |
290 | const TfLiteTensor* cell_to_forget_weights = |
291 | GetOptionalInputTensor(context, node, cell_to_forget_weights_tensor); |
292 | if (cell_to_forget_weights != nullptr) { |
293 | TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); |
294 | TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); |
295 | TF_LITE_ENSURE_TYPES_EQ(context, cell_to_forget_weights->type, |
296 | input_to_forget_weights->type); |
297 | } |
298 | |
299 | const TfLiteTensor* cell_to_output_weights = |
300 | GetOptionalInputTensor(context, node, cell_to_output_weights_tensor); |
301 | if (cell_to_output_weights != nullptr) { |
302 | TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); |
303 | TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); |
304 | TF_LITE_ENSURE_TYPES_EQ(context, cell_to_output_weights->type, |
305 | input_to_forget_weights->type); |
306 | } |
307 | |
  // Making sure the peephole weights are either all present or all absent.
309 | const bool use_cifg = (input_to_input_weights == nullptr); |
310 | const bool peephole_weights_all_or_none = |
311 | ((cell_to_input_weights != nullptr || use_cifg) && |
312 | (cell_to_forget_weights != nullptr) && |
313 | (cell_to_output_weights != nullptr)) || |
314 | ((cell_to_input_weights == nullptr) && |
315 | (cell_to_forget_weights == nullptr) && |
316 | (cell_to_output_weights == nullptr)); |
317 | TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); |
318 | |
319 | // Make sure the input gate bias is present only when not a CIFG-LSTM. |
320 | const TfLiteTensor* input_gate_bias = |
321 | GetOptionalInputTensor(context, node, input_gate_bias_tensor); |
322 | if (use_cifg) { |
323 | TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); |
324 | } else { |
325 | TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); |
326 | TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); |
327 | TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32); |
328 | } |
329 | |
330 | const TfLiteTensor* forget_gate_bias; |
331 | TF_LITE_ENSURE_OK( |
332 | context, |
333 | GetInputSafe(context, node, forget_gate_bias_tensor, &forget_gate_bias)); |
334 | TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); |
335 | TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); |
336 | TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32); |
337 | |
338 | const TfLiteTensor* cell_gate_bias; |
339 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, cell_gate_bias_tensor, |
340 | &cell_gate_bias)); |
341 | TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1); |
342 | TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell); |
  TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32);
344 | |
345 | const TfLiteTensor* output_gate_bias; |
346 | TF_LITE_ENSURE_OK( |
347 | context, |
348 | GetInputSafe(context, node, output_gate_bias_tensor, &output_gate_bias)); |
349 | TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); |
350 | TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); |
351 | TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32); |
352 | |
353 | const TfLiteTensor* projection_weights = |
354 | GetOptionalInputTensor(context, node, projection_weights_tensor); |
355 | if (projection_weights != nullptr) { |
356 | TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); |
357 | TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); |
358 | TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); |
359 | TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type, |
360 | input_to_forget_weights->type); |
361 | } |
362 | |
363 | const TfLiteTensor* projection_bias = |
364 | GetOptionalInputTensor(context, node, projection_bias_tensor); |
365 | if (projection_bias != nullptr) { |
366 | TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); |
367 | TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); |
368 | TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32); |
369 | } |
370 | |
371 | // Making sure the projection tensors are consistent: |
372 | // 1) If projection weight is not present, then projection bias should not be |
373 | // present. |
374 | // 2) If projection weight is present, then projection bias is optional. |
375 | // TODO(ghodrat): make sure this is correct. |
  const bool projection_tensors_consistent =
      ((projection_weights != nullptr) || (projection_bias == nullptr));
  TF_LITE_ENSURE(context, projection_tensors_consistent == true);
379 | |
380 | return kTfLiteOk; |
381 | } |
382 | |
383 | TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, |
384 | TfLiteNode* node, int n_input, |
385 | int n_output, int n_cell) { |
386 | TF_LITE_ENSURE_OK( |
387 | context, |
388 | CheckLstmTensorDimensionsAndTypes( |
389 | context, node, n_input, n_output, n_cell, |
390 | kFwInputToInputWeightsTensor, kFwInputToForgetWeightsTensor, |
391 | kFwInputToCellWeightsTensor, kFwInputToOutputWeightsTensor, |
392 | kFwRecurrentToInputWeightsTensor, kFwRecurrentToForgetWeightsTensor, |
393 | kFwRecurrentToCellWeightsTensor, kFwRecurrentToOutputWeightsTensor, |
394 | kFwCellToInputWeightsTensor, kFwCellToForgetWeightsTensor, |
395 | kFwCellToOutputWeightsTensor, kFwInputGateBiasTensor, |
396 | kFwForgetGateBiasTensor, kFwCellGateBiasTensor, |
397 | kFwOutputGateBiasTensor, kFwProjectionWeightsTensor, |
398 | kFwProjectionBiasTensor)); |
399 | |
400 | TF_LITE_ENSURE_OK( |
401 | context, |
402 | CheckLstmTensorDimensionsAndTypes( |
403 | context, node, n_input, n_output, n_cell, |
404 | kBwInputToInputWeightsTensor, kBwInputToForgetWeightsTensor, |
405 | kBwInputToCellWeightsTensor, kBwInputToOutputWeightsTensor, |
406 | kBwRecurrentToInputWeightsTensor, kBwRecurrentToForgetWeightsTensor, |
407 | kBwRecurrentToCellWeightsTensor, kBwRecurrentToOutputWeightsTensor, |
408 | kBwCellToInputWeightsTensor, kBwCellToForgetWeightsTensor, |
409 | kBwCellToOutputWeightsTensor, kBwInputGateBiasTensor, |
410 | kBwForgetGateBiasTensor, kBwCellGateBiasTensor, |
411 | kBwOutputGateBiasTensor, kBwProjectionWeightsTensor, |
412 | kBwProjectionBiasTensor)); |
413 | |
  // Checking both cells against the same n_input/n_output/n_cell above
  // implicitly ensures that the forward and backward tensors match along the
  // required dimensions.
415 | return kTfLiteOk; |
416 | } |
417 | |
// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the sizes of the input tensors match each other.
420 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
421 | auto* op_data = reinterpret_cast<OpData*>(node->user_data); |
422 | const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>( |
423 | node->builtin_data); |
424 | |
425 | // Check we have all the inputs and outputs we need. |
426 | TF_LITE_ENSURE_EQ(context, node->inputs->size, 48); |
427 | TF_LITE_ENSURE_EQ(context, node->outputs->size, |
428 | params->merge_outputs ? 1 : 2); |
429 | |
  // Infer the batch size, sequence length, number of outputs and number of
  // cells from the input tensors.
432 | const TfLiteTensor* input; |
433 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
434 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32); |
435 | TF_LITE_ENSURE_EQ(context, input->dims->size, 3); |
436 | const bool time_major = params->time_major; |
437 | const int max_time = time_major ? input->dims->data[0] : input->dims->data[1]; |
438 | const int n_batch = time_major ? input->dims->data[1] : input->dims->data[0]; |
439 | const int n_input = input->dims->data[2]; |
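  // For example, a time-major input of shape {4, 2, 8} yields max_time = 4,
  // n_batch = 2 and n_input = 8; with time_major == false the first two
  // dimensions are swapped, i.e. the same data would be shaped {2, 4, 8}.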
440 | |
441 | const TfLiteTensor* fw_input_to_output_weights; |
442 | TF_LITE_ENSURE_OK(context, |
443 | GetInputSafe(context, node, kFwInputToOutputWeightsTensor, |
444 | &fw_input_to_output_weights)); |
445 | const int n_fw_cell = fw_input_to_output_weights->dims->data[0]; |
446 | TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2); |
447 | TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1], |
448 | n_input); |
449 | |
450 | const TfLiteTensor* bw_input_to_output_weights; |
451 | TF_LITE_ENSURE_OK(context, |
452 | GetInputSafe(context, node, kBwInputToOutputWeightsTensor, |
453 | &bw_input_to_output_weights)); |
454 | const int n_bw_cell = bw_input_to_output_weights->dims->data[0]; |
455 | TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2); |
456 | TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1], |
457 | n_input); |
  TF_LITE_ENSURE_TYPES_EQ(context, bw_input_to_output_weights->type,
                          fw_input_to_output_weights->type);
460 | |
461 | const TfLiteTensor* fw_recurrent_to_output_weights; |
462 | TF_LITE_ENSURE_OK( |
463 | context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor, |
464 | &fw_recurrent_to_output_weights)); |
465 | TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2); |
466 | TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0], |
467 | n_fw_cell); |
  TF_LITE_ENSURE_TYPES_EQ(context, fw_recurrent_to_output_weights->type,
                          fw_input_to_output_weights->type);
470 | const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1]; |
471 | |
472 | const TfLiteTensor* bw_recurrent_to_output_weights; |
473 | TF_LITE_ENSURE_OK( |
474 | context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor, |
475 | &bw_recurrent_to_output_weights)); |
476 | TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2); |
477 | TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0], |
478 | n_bw_cell); |
  TF_LITE_ENSURE_TYPES_EQ(context, bw_recurrent_to_output_weights->type,
                          fw_input_to_output_weights->type);
481 | const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; |
482 | |
  // Check that input tensor dimensions match each other.
484 | TF_LITE_ENSURE_OK( |
485 | context, CheckInputTensorDimensions(context, node, n_input, n_fw_output, |
486 | n_fw_cell)); |
487 | |
488 | // Get (optional) auxiliary inputs and weights. |
489 | const TfLiteTensor* aux_input = |
490 | GetOptionalInputTensor(context, node, kAuxInputTensor); |
491 | const TfLiteTensor* fw_aux_input_to_input_weights = |
492 | GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor); |
493 | const TfLiteTensor* fw_aux_input_to_forget_weights = |
494 | GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor); |
495 | const TfLiteTensor* fw_aux_input_to_cell_weights = |
496 | GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor); |
497 | const TfLiteTensor* fw_aux_input_to_output_weights = |
498 | GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor); |
499 | const TfLiteTensor* bw_aux_input_to_input_weights = |
500 | GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor); |
501 | const TfLiteTensor* bw_aux_input_to_forget_weights = |
502 | GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor); |
503 | const TfLiteTensor* bw_aux_input_to_cell_weights = |
504 | GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor); |
505 | const TfLiteTensor* bw_aux_input_to_output_weights = |
506 | GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor); |
507 | |
508 | const bool aux_inputs_weights_all_or_none = |
509 | ((fw_aux_input_to_cell_weights != nullptr) && |
510 | (fw_aux_input_to_forget_weights != nullptr) && |
511 | (fw_aux_input_to_output_weights != nullptr) && |
512 | (bw_aux_input_to_cell_weights != nullptr) && |
513 | (bw_aux_input_to_forget_weights != nullptr) && |
514 | (bw_aux_input_to_output_weights != nullptr)) || |
515 | ((fw_aux_input_to_cell_weights == nullptr) && |
516 | (fw_aux_input_to_forget_weights == nullptr) && |
517 | (fw_aux_input_to_output_weights == nullptr) && |
518 | (bw_aux_input_to_cell_weights == nullptr) && |
519 | (bw_aux_input_to_forget_weights == nullptr) && |
520 | (bw_aux_input_to_output_weights == nullptr)); |
521 | TF_LITE_ENSURE(context, aux_inputs_weights_all_or_none); |
522 | |
523 | const bool has_aux_input = (fw_aux_input_to_forget_weights != nullptr); |
524 | |
525 | if (has_aux_input) { |
526 | // Check that aux_input has the same dimensions (except last) as the input. |
527 | TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]); |
528 | TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]); |
529 | } |
530 | |
  // Get pointers to the output, activation_state and cell_state buffer
  // tensors.
532 | TfLiteTensor* fw_output; |
533 | TF_LITE_ENSURE_OK(context, |
534 | GetOutputSafe(context, node, kFwOutputTensor, &fw_output)); |
535 | TfLiteTensor* fw_activation_state = |
536 | GetVariableInput(context, node, kFwInputActivationStateTensor); |
537 | TF_LITE_ENSURE(context, fw_activation_state != nullptr); |
538 | TfLiteTensor* fw_cell_state = |
539 | GetVariableInput(context, node, kFwInputCellStateTensor); |
540 | TF_LITE_ENSURE(context, fw_cell_state != nullptr); |
541 | |
  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; any rank is fine as long as the total
  // number of elements is correct.
545 | TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state), |
546 | n_batch * n_fw_output); |
547 | TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell); |
548 | |
549 | // Resize the output tensors. |
550 | TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3); |
551 | fw_output_size->data[0] = time_major ? max_time : n_batch; |
552 | fw_output_size->data[1] = time_major ? n_batch : max_time; |
553 | fw_output_size->data[2] = |
554 | params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output; |
555 | TF_LITE_ENSURE_OK(context, |
556 | context->ResizeTensor(context, fw_output, fw_output_size)); |
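  // E.g. with time_major == false, n_batch = 2, max_time = 4 and
  // n_fw_output == n_bw_output == 16, fw_output is resized to {2, 4, 32}
  // when merge_outputs is set and to {2, 4, 16} otherwise.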
557 | |
558 | // The weights are of consistent type, so it suffices to check one. |
559 | const bool is_hybrid_op = IsHybridOp(input, fw_input_to_output_weights); |
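  // A hybrid op here means float32 activations with quantized (int8/uint8)
  // weights; the quantized temporaries allocated below hold per-invocation
  // quantized copies of the float inputs and states.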
560 | |
561 | TfLiteIntArrayFree(node->temporaries); |
562 | if (is_hybrid_op) { |
563 | node->temporaries = TfLiteIntArrayCreate( |
564 | has_aux_input ? kNumTemporaryTensors : kNumTemporaryTensors - 1); |
565 | } else { |
566 | node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers. |
567 | } |
568 | // Create a scratch buffer tensor. |
569 | node->temporaries->data[kFwScratchBuffer] = |
570 | op_data->scratch_tensor_index + kFwScratchBuffer; |
571 | TfLiteTensor* fw_scratch_buffer; |
572 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer, |
573 | &fw_scratch_buffer)); |
574 | fw_scratch_buffer->type = input->type; |
575 | fw_scratch_buffer->allocation_type = kTfLiteArenaRw; |
576 | |
577 | const TfLiteTensor* fw_input_to_input_weights = |
578 | GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor); |
579 | const bool fw_use_cifg = (fw_input_to_input_weights == nullptr); |
580 | if (has_aux_input && !fw_use_cifg) { |
581 | TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0], |
582 | fw_input_to_input_weights->dims->data[0]); |
583 | } |
584 | TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2); |
585 | fw_scratch_buffer_size->data[0] = n_batch; |
586 | if (fw_use_cifg) { |
587 | // Reserving space for Cell, Forget, Output gates and scratch accumulation |
588 | // buffer and an extra 16 bytes to avoid internal ruy copies. |
589 | fw_scratch_buffer_size->data[1] = n_fw_cell * 4 + 16; |
590 | } else { |
591 | // Reserving space for Input, Cell, Forget, Output gates and scratch |
592 | // accumulation buffer and an extra 16 bytes to avoid internal ruy copies. |
593 | fw_scratch_buffer_size->data[1] = n_fw_cell * 5 + 16; |
594 | } |
595 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer, |
596 | fw_scratch_buffer_size)); |
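  // For example, with n_batch = 2 and n_fw_cell = 32 in the non-CIFG case,
  // the forward scratch buffer is resized to {2, 32 * 5 + 16} = {2, 176}
  // elements.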
597 | // Same for the backward cell. |
598 | |
  // Check that input tensor dimensions match each other.
600 | TF_LITE_ENSURE_OK( |
601 | context, CheckInputTensorDimensions(context, node, n_input, n_bw_output, |
602 | n_bw_cell)); |
603 | |
  // Get pointers to the activation_state and cell_state buffer tensors.
605 | TfLiteTensor* bw_activation_state = |
606 | GetVariableInput(context, node, kBwInputActivationStateTensor); |
607 | TF_LITE_ENSURE(context, bw_activation_state != nullptr); |
608 | TfLiteTensor* bw_cell_state = |
609 | GetVariableInput(context, node, kBwInputCellStateTensor); |
610 | TF_LITE_ENSURE(context, bw_cell_state != nullptr); |
611 | |
612 | // Resize the output tensors. |
613 | if (!params->merge_outputs) { |
614 | TfLiteTensor* bw_output; |
615 | TF_LITE_ENSURE_OK( |
616 | context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output)); |
617 | TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3); |
618 | bw_output_size->data[0] = time_major ? max_time : n_batch; |
619 | bw_output_size->data[1] = time_major ? n_batch : max_time; |
620 | bw_output_size->data[2] = n_bw_output; |
621 | TF_LITE_ENSURE_OK( |
622 | context, context->ResizeTensor(context, bw_output, bw_output_size)); |
623 | } |
624 | |
  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D; any rank is fine as long as the total
  // number of elements is correct.
628 | TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state), |
629 | n_batch * n_bw_output); |
630 | TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell); |
631 | |
632 | // Create a scratch buffer tensor. |
633 | node->temporaries->data[kBwScratchBuffer] = |
634 | op_data->scratch_tensor_index + kBwScratchBuffer; |
635 | TfLiteTensor* bw_scratch_buffer; |
636 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer, |
637 | &bw_scratch_buffer)); |
638 | bw_scratch_buffer->type = input->type; |
639 | bw_scratch_buffer->allocation_type = kTfLiteArenaRw; |
640 | |
641 | const TfLiteTensor* bw_input_to_input_weights = |
642 | GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor); |
643 | const bool bw_use_cifg = (bw_input_to_input_weights == nullptr); |
644 | if (has_aux_input && !bw_use_cifg) { |
645 | TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0], |
646 | bw_input_to_input_weights->dims->data[0]); |
647 | } |
648 | TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2); |
649 | bw_scratch_buffer_size->data[0] = n_batch; |
  if (bw_use_cifg) {
    // Reserving space for Cell, Forget, Output gates and scratch accumulation
    // buffer and an extra 16 bytes to avoid internal ruy copies.
    bw_scratch_buffer_size->data[1] = n_bw_cell * 4 + 16;
  } else {
    // Reserving space for Input, Cell, Forget, Output gates and scratch
    // accumulation buffer and an extra 16 bytes to avoid internal ruy copies.
    bw_scratch_buffer_size->data[1] = n_bw_cell * 5 + 16;
  }
659 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer, |
660 | bw_scratch_buffer_size)); |
661 | if (is_hybrid_op) { |
662 | // Compute the row sums for cached zero_point offset calculation. |
663 | op_data->compute_fw_row_sums = true; |
664 | op_data->compute_bw_row_sums = true; |
665 | // Allocate temporary tensors to store quantized values of input, aux_input |
666 | // (if present), activation_state and cell_state tensors. |
667 | node->temporaries->data[kInputQuantized] = |
668 | op_data->scratch_tensor_index + kInputQuantized; |
669 | TfLiteTensor* input_quantized; |
670 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized, |
671 | &input_quantized)); |
672 | input_quantized->type = fw_input_to_output_weights->type; |
673 | input_quantized->allocation_type = kTfLiteArenaRw; |
674 | if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { |
675 | TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); |
676 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, |
677 | input_quantized_size)); |
678 | } |
679 | |
680 | node->temporaries->data[kFwActivationStateQuantized] = |
681 | op_data->scratch_tensor_index + kFwActivationStateQuantized; |
682 | TfLiteTensor* fw_activation_state_quantized; |
683 | TF_LITE_ENSURE_OK( |
684 | context, GetTemporarySafe(context, node, kFwActivationStateQuantized, |
685 | &fw_activation_state_quantized)); |
686 | fw_activation_state_quantized->type = fw_input_to_output_weights->type; |
687 | fw_activation_state_quantized->allocation_type = kTfLiteArenaRw; |
688 | if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims, |
689 | fw_activation_state->dims)) { |
690 | TfLiteIntArray* fw_activation_state_quantized_size = |
691 | TfLiteIntArrayCopy(fw_activation_state->dims); |
692 | TF_LITE_ENSURE_OK( |
693 | context, context->ResizeTensor(context, fw_activation_state_quantized, |
694 | fw_activation_state_quantized_size)); |
695 | } |
696 | node->temporaries->data[kBwActivationStateQuantized] = |
697 | op_data->scratch_tensor_index + kBwActivationStateQuantized; |
698 | TfLiteTensor* bw_activation_state_quantized; |
699 | TF_LITE_ENSURE_OK( |
700 | context, GetTemporarySafe(context, node, kBwActivationStateQuantized, |
701 | &bw_activation_state_quantized)); |
702 | bw_activation_state_quantized->type = fw_input_to_output_weights->type; |
703 | bw_activation_state_quantized->allocation_type = kTfLiteArenaRw; |
704 | if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims, |
705 | bw_activation_state->dims)) { |
706 | TfLiteIntArray* bw_activation_state_quantized_size = |
707 | TfLiteIntArrayCopy(bw_activation_state->dims); |
708 | TF_LITE_ENSURE_OK( |
709 | context, context->ResizeTensor(context, bw_activation_state_quantized, |
710 | bw_activation_state_quantized_size)); |
711 | } |
712 | node->temporaries->data[kFwCellStateQuantized] = |
713 | op_data->scratch_tensor_index + kFwCellStateQuantized; |
714 | TfLiteTensor* fw_cell_state_quantized; |
715 | TF_LITE_ENSURE_OK(context, |
716 | GetTemporarySafe(context, node, kFwCellStateQuantized, |
717 | &fw_cell_state_quantized)); |
718 | fw_cell_state_quantized->type = fw_input_to_output_weights->type; |
719 | fw_cell_state_quantized->allocation_type = kTfLiteArenaRw; |
720 | if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims, |
721 | fw_cell_state->dims)) { |
722 | TfLiteIntArray* fw_cell_state_quantized_size = |
723 | TfLiteIntArrayCopy(fw_cell_state->dims); |
724 | TF_LITE_ENSURE_OK(context, |
725 | context->ResizeTensor(context, fw_cell_state_quantized, |
726 | fw_cell_state_quantized_size)); |
727 | } |
728 | node->temporaries->data[kBwCellStateQuantized] = |
729 | op_data->scratch_tensor_index + kBwCellStateQuantized; |
730 | TfLiteTensor* bw_cell_state_quantized; |
731 | TF_LITE_ENSURE_OK(context, |
732 | GetTemporarySafe(context, node, kBwCellStateQuantized, |
733 | &bw_cell_state_quantized)); |
734 | bw_cell_state_quantized->type = fw_input_to_output_weights->type; |
735 | bw_cell_state_quantized->allocation_type = kTfLiteArenaRw; |
736 | if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims, |
737 | bw_cell_state->dims)) { |
738 | TfLiteIntArray* bw_cell_state_quantized_size = |
739 | TfLiteIntArrayCopy(bw_cell_state->dims); |
740 | TF_LITE_ENSURE_OK(context, |
741 | context->ResizeTensor(context, bw_cell_state_quantized, |
742 | bw_cell_state_quantized_size)); |
743 | } |
744 | |
    // Allocate temporary tensors to store scaling factors and product scaling
    // factors. The latter is convenience storage that lets a vector be
    // quantized once (producing the scaling factors) and then multiplied with
    // different matrices (each of which requires multiplying the scaling
    // factors by the matrix's own scaling factor).
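    // Illustrative example (hypothetical names): a hybrid matmul computes
    // approximately input_sf[b] * weight_scale * dot(q_input[b], q_row[i]);
    // the product scaling factor cached per batch entry b is
    // input_sf[b] * weight_scale.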
750 | node->temporaries->data[kInputScalingFactors] = |
751 | op_data->scratch_tensor_index + kInputScalingFactors; |
752 | TfLiteTensor* input_sf; |
753 | TF_LITE_ENSURE_OK( |
754 | context, |
755 | GetTemporarySafe(context, node, kInputScalingFactors, &input_sf)); |
756 | input_sf->type = kTfLiteFloat32; |
757 | input_sf->allocation_type = kTfLiteArenaRw; |
758 | int scaling_dims[1] = {n_batch}; |
759 | if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) { |
760 | TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1); |
761 | input_sf_size->data[0] = n_batch; |
762 | TF_LITE_ENSURE_OK( |
763 | context, context->ResizeTensor(context, input_sf, input_sf_size)); |
764 | } |
765 | node->temporaries->data[kAuxInputScalingFactors] = |
766 | op_data->scratch_tensor_index + kAuxInputScalingFactors; |
767 | TfLiteTensor* aux_input_sf; |
768 | TF_LITE_ENSURE_OK(context, |
769 | GetTemporarySafe(context, node, kAuxInputScalingFactors, |
770 | &aux_input_sf)); |
771 | aux_input_sf->type = kTfLiteFloat32; |
772 | aux_input_sf->allocation_type = kTfLiteArenaRw; |
773 | if (!TfLiteIntArrayEqualsArray(aux_input_sf->dims, 1, scaling_dims)) { |
774 | TfLiteIntArray* aux_input_sf_size = TfLiteIntArrayCreate(1); |
775 | aux_input_sf_size->data[0] = n_batch; |
776 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_sf, |
777 | aux_input_sf_size)); |
778 | } |
779 | node->temporaries->data[kOutputStateScalingFactors] = |
780 | op_data->scratch_tensor_index + kOutputStateScalingFactors; |
781 | TfLiteTensor* output_state_sf; |
782 | TF_LITE_ENSURE_OK( |
783 | context, GetTemporarySafe(context, node, kOutputStateScalingFactors, |
784 | &output_state_sf)); |
785 | output_state_sf->type = kTfLiteFloat32; |
786 | output_state_sf->allocation_type = kTfLiteArenaRw; |
787 | if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { |
788 | TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1); |
789 | output_state_sf_size->data[0] = n_batch; |
790 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf, |
791 | output_state_sf_size)); |
792 | } |
793 | node->temporaries->data[kProductScalingFactors] = |
794 | op_data->scratch_tensor_index + kProductScalingFactors; |
795 | TfLiteTensor* prod_scaling_factors; |
796 | TF_LITE_ENSURE_OK(context, |
797 | GetTemporarySafe(context, node, kProductScalingFactors, |
798 | &prod_scaling_factors)); |
799 | prod_scaling_factors->type = kTfLiteFloat32; |
800 | prod_scaling_factors->allocation_type = kTfLiteArenaRw; |
801 | if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1, |
802 | scaling_dims)) { |
803 | TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); |
804 | prod_scaling_factors_size->data[0] = n_batch; |
805 | TF_LITE_ENSURE_OK(context, |
806 | context->ResizeTensor(context, prod_scaling_factors, |
807 | prod_scaling_factors_size)); |
808 | } |
809 | |
    // Allocate a temporary tensor to store the recovered cell weights. Since
    // these are used for diagonal matrices, only n_cell values need to be
    // stored.
812 | node->temporaries->data[kRecoveredCellWeights] = |
813 | op_data->scratch_tensor_index + kRecoveredCellWeights; |
814 | TfLiteTensor* recovered_cell_weights; |
815 | TF_LITE_ENSURE_OK(context, |
816 | GetTemporarySafe(context, node, kRecoveredCellWeights, |
817 | &recovered_cell_weights)); |
818 | recovered_cell_weights->type = kTfLiteFloat32; |
819 | recovered_cell_weights->allocation_type = kTfLiteArenaRw; |
820 | int recovered_cell_dims[1] = {n_fw_cell}; |
821 | if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1, |
822 | recovered_cell_dims)) { |
823 | TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); |
824 | recovered_cell_weights_size->data[0] = n_fw_cell; |
825 | TF_LITE_ENSURE_OK(context, |
826 | context->ResizeTensor(context, recovered_cell_weights, |
827 | recovered_cell_weights_size)); |
828 | } |
829 | |
830 | // Allocate a temporary tensor to store the accumulated int32 values. |
831 | node->temporaries->data[kAccumScratchBuffer] = |
832 | op_data->scratch_tensor_index + kAccumScratchBuffer; |
833 | TfLiteTensor* accum_scratch; |
834 | TF_LITE_ENSURE_OK( |
835 | context, |
836 | GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch)); |
837 | accum_scratch->type = kTfLiteInt32; |
838 | accum_scratch->allocation_type = kTfLiteArenaRw; |
839 | int n_cell = std::max(n_fw_cell, n_bw_cell); |
840 | if (has_aux_input) { |
841 | n_cell = std::max(n_cell, fw_aux_input_to_output_weights->dims->data[0]); |
842 | n_cell = std::max(n_cell, bw_aux_input_to_output_weights->dims->data[0]); |
843 | } |
844 | int accum_scratch_dims[2] = {n_cell, n_batch}; |
845 | if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, |
846 | accum_scratch_dims)) { |
847 | TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2); |
848 | accum_size->data[0] = n_cell; |
849 | accum_size->data[1] = n_batch; |
850 | TF_LITE_ENSURE_OK( |
851 | context, context->ResizeTensor(context, accum_scratch, accum_size)); |
852 | } |
853 | |
854 | // Allocate temporary tensors for storing zero-points. |
855 | node->temporaries->data[kInputZeroPoints] = |
856 | op_data->scratch_tensor_index + kInputZeroPoints; |
857 | TfLiteTensor* input_zp; |
858 | TF_LITE_ENSURE_OK( |
859 | context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp)); |
860 | input_zp->type = kTfLiteFloat32; |
861 | input_zp->allocation_type = kTfLiteArenaRw; |
862 | if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { |
863 | TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1); |
864 | input_zp_size->data[0] = n_batch; |
865 | TF_LITE_ENSURE_OK( |
866 | context, context->ResizeTensor(context, input_zp, input_zp_size)); |
867 | } |
868 | node->temporaries->data[kAuxInputZeroPoints] = |
869 | op_data->scratch_tensor_index + kAuxInputZeroPoints; |
870 | TfLiteTensor* aux_input_zp; |
871 | TF_LITE_ENSURE_OK( |
872 | context, |
873 | GetTemporarySafe(context, node, kAuxInputZeroPoints, &aux_input_zp)); |
874 | aux_input_zp->type = kTfLiteFloat32; |
875 | aux_input_zp->allocation_type = kTfLiteArenaRw; |
876 | if (!TfLiteIntArrayEqualsArray(aux_input_zp->dims, 1, scaling_dims)) { |
877 | TfLiteIntArray* aux_input_zp_size = TfLiteIntArrayCreate(1); |
878 | aux_input_zp_size->data[0] = n_batch; |
879 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, aux_input_zp, |
880 | aux_input_zp_size)); |
881 | } |
882 | node->temporaries->data[kOutputStateZeroPoints] = |
883 | op_data->scratch_tensor_index + kOutputStateZeroPoints; |
884 | TfLiteTensor* output_state_zp; |
885 | TF_LITE_ENSURE_OK(context, |
886 | GetTemporarySafe(context, node, kOutputStateZeroPoints, |
887 | &output_state_zp)); |
888 | output_state_zp->type = kTfLiteFloat32; |
889 | output_state_zp->allocation_type = kTfLiteArenaRw; |
890 | if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { |
891 | TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1); |
892 | output_state_zp_size->data[0] = n_batch; |
893 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp, |
894 | output_state_zp_size)); |
895 | } |
896 | |
897 | // Allocate temporary tensors for caching row sums for hybrid zero-point |
898 | // calculations. |
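    // Row-count sketch: the four input-to-gate and four recurrent-to-gate
    // matrices contribute 8 rows (6 under CIFG, which drops the input gate);
    // auxiliary input adds one row per aux-to-gate matrix (4, or 3 under
    // CIFG), and a projection matrix adds ceil(n_output / n_cell) more, since
    // its n_output row sums are packed into rows of width n_cell.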
899 | int fw_row_sums_rows = fw_use_cifg ? 6 : 8; |
900 | if (has_aux_input) { |
901 | fw_row_sums_rows += fw_use_cifg ? 3 : 4; |
902 | } |
903 | const TfLiteTensor* fw_projection_weights = |
904 | GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor); |
905 | if (fw_projection_weights != nullptr) { |
906 | fw_row_sums_rows += ceil(static_cast<float>(n_fw_output) / n_fw_cell); |
907 | } |
908 | node->temporaries->data[kFwRowSums] = |
909 | op_data->scratch_tensor_index + kFwRowSums; |
910 | TfLiteTensor* fw_row_sums; |
911 | TF_LITE_ENSURE_OK( |
912 | context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums)); |
913 | fw_row_sums->type = kTfLiteInt32; |
914 | fw_row_sums->allocation_type = kTfLiteArenaRwPersistent; |
915 | int fw_row_sums_dims[2] = {fw_row_sums_rows, n_fw_cell}; |
916 | if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) { |
917 | TfLiteIntArray* fw_hybrid_scratch_size = TfLiteIntArrayCreate(2); |
918 | fw_hybrid_scratch_size->data[0] = fw_row_sums_dims[0]; |
919 | fw_hybrid_scratch_size->data[1] = fw_row_sums_dims[1]; |
920 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums, |
921 | fw_hybrid_scratch_size)); |
922 | } |
923 | |
924 | int bw_row_sums_rows = bw_use_cifg ? 6 : 8; |
925 | if (has_aux_input) { |
926 | bw_row_sums_rows += bw_use_cifg ? 3 : 4; |
927 | } |
928 | const TfLiteTensor* bw_projection_weights = |
929 | GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor); |
930 | if (bw_projection_weights != nullptr) { |
931 | bw_row_sums_rows += ceil(static_cast<float>(n_bw_output) / n_bw_cell); |
932 | } |
933 | node->temporaries->data[kBwRowSums] = |
934 | op_data->scratch_tensor_index + kBwRowSums; |
935 | TfLiteTensor* bw_row_sums; |
936 | TF_LITE_ENSURE_OK( |
937 | context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums)); |
938 | bw_row_sums->type = kTfLiteInt32; |
939 | bw_row_sums->allocation_type = kTfLiteArenaRwPersistent; |
940 | int bw_row_sums_dims[2] = {bw_row_sums_rows, n_bw_cell}; |
941 | if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) { |
942 | TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2); |
943 | bw_row_sums_size->data[0] = bw_row_sums_dims[0]; |
944 | bw_row_sums_size->data[1] = bw_row_sums_dims[1]; |
945 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums, |
946 | bw_row_sums_size)); |
947 | } |
948 | |
949 | // Only allocate a temporary tensor for quantized auxiliary input if we are |
950 | // actually going to use it. |
951 | if (has_aux_input) { |
952 | node->temporaries->data[kAuxInputQuantized] = |
953 | op_data->scratch_tensor_index + kAuxInputQuantized; |
954 | TfLiteTensor* aux_input_quantized; |
955 | TF_LITE_ENSURE_OK(context, |
956 | GetTemporarySafe(context, node, kAuxInputQuantized, |
957 | &aux_input_quantized)); |
958 | aux_input_quantized->type = fw_input_to_output_weights->type; |
959 | aux_input_quantized->allocation_type = kTfLiteArenaRw; |
960 | if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) { |
961 | TfLiteIntArray* aux_input_quantized_size = |
962 | TfLiteIntArrayCopy(aux_input->dims); |
963 | TF_LITE_ENSURE_OK(context, |
964 | context->ResizeTensor(context, aux_input_quantized, |
965 | aux_input_quantized_size)); |
966 | } |
967 | } |
968 | } |
969 | return kTfLiteOk; |
970 | } |
971 | |
972 | // The LSTM Op engine. |
973 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
974 | const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>( |
975 | node->builtin_data); |
976 | auto* op_data = reinterpret_cast<OpData*>(node->user_data); |
977 | // Input tensor. |
978 | const TfLiteTensor* input; |
979 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
980 | |
981 | // Tensors for the forward cell. |
982 | const TfLiteTensor* fw_input_to_input_weights = |
983 | GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor); |
984 | const TfLiteTensor* fw_input_to_forget_weights; |
985 | TF_LITE_ENSURE_OK(context, |
986 | GetInputSafe(context, node, kFwInputToForgetWeightsTensor, |
987 | &fw_input_to_forget_weights)); |
988 | const TfLiteTensor* fw_input_to_cell_weights; |
989 | TF_LITE_ENSURE_OK(context, |
990 | GetInputSafe(context, node, kFwInputToCellWeightsTensor, |
991 | &fw_input_to_cell_weights)); |
992 | const TfLiteTensor* fw_input_to_output_weights; |
993 | TF_LITE_ENSURE_OK(context, |
994 | GetInputSafe(context, node, kFwInputToOutputWeightsTensor, |
995 | &fw_input_to_output_weights)); |
996 | |
997 | const TfLiteTensor* fw_recurrent_to_input_weights = |
998 | GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor); |
999 | const TfLiteTensor* fw_recurrent_to_forget_weights; |
1000 | TF_LITE_ENSURE_OK( |
1001 | context, GetInputSafe(context, node, kFwRecurrentToForgetWeightsTensor, |
1002 | &fw_recurrent_to_forget_weights)); |
1003 | const TfLiteTensor* fw_recurrent_to_cell_weights; |
1004 | TF_LITE_ENSURE_OK(context, |
1005 | GetInputSafe(context, node, kFwRecurrentToCellWeightsTensor, |
1006 | &fw_recurrent_to_cell_weights)); |
1007 | const TfLiteTensor* fw_recurrent_to_output_weights; |
1008 | TF_LITE_ENSURE_OK( |
1009 | context, GetInputSafe(context, node, kFwRecurrentToOutputWeightsTensor, |
1010 | &fw_recurrent_to_output_weights)); |
1011 | |
1012 | const TfLiteTensor* fw_cell_to_input_weights = |
1013 | GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor); |
1014 | const TfLiteTensor* fw_cell_to_forget_weights = |
1015 | GetOptionalInputTensor(context, node, kFwCellToForgetWeightsTensor); |
1016 | const TfLiteTensor* fw_cell_to_output_weights = |
1017 | GetOptionalInputTensor(context, node, kFwCellToOutputWeightsTensor); |
1018 | |
1019 | const TfLiteTensor* fw_input_gate_bias = |
1020 | GetOptionalInputTensor(context, node, kFwInputGateBiasTensor); |
1021 | const TfLiteTensor* fw_forget_gate_bias; |
1022 | TF_LITE_ENSURE_OK(context, |
1023 | GetInputSafe(context, node, kFwForgetGateBiasTensor, |
1024 | &fw_forget_gate_bias)); |
1025 | const TfLiteTensor* fw_cell_gate_bias; |
1026 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwCellGateBiasTensor, |
1027 | &fw_cell_gate_bias)); |
1028 | const TfLiteTensor* fw_output_gate_bias; |
1029 | TF_LITE_ENSURE_OK(context, |
1030 | GetInputSafe(context, node, kFwOutputGateBiasTensor, |
1031 | &fw_output_gate_bias)); |
1032 | |
1033 | const TfLiteTensor* fw_projection_weights = |
1034 | GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor); |
1035 | const TfLiteTensor* fw_projection_bias = |
1036 | GetOptionalInputTensor(context, node, kFwProjectionBiasTensor); |
1037 | |
1038 | TfLiteTensor* fw_activation_state = |
1039 | GetVariableInput(context, node, kFwInputActivationStateTensor); |
1040 | TFLITE_DCHECK(fw_activation_state != nullptr); |
1041 | TfLiteTensor* fw_cell_state = |
1042 | GetVariableInput(context, node, kFwInputCellStateTensor); |
1043 | TFLITE_DCHECK(fw_cell_state != nullptr); |
1044 | TfLiteTensor* fw_output; |
1045 | TF_LITE_ENSURE_OK(context, |
1046 | GetOutputSafe(context, node, kFwOutputTensor, &fw_output)); |
1047 | |
1048 | // Tensors for the backward cell. |
1049 | const TfLiteTensor* bw_input_to_input_weights = |
1050 | GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor); |
1051 | const TfLiteTensor* bw_input_to_forget_weights; |
1052 | TF_LITE_ENSURE_OK(context, |
1053 | GetInputSafe(context, node, kBwInputToForgetWeightsTensor, |
1054 | &bw_input_to_forget_weights)); |
1055 | const TfLiteTensor* bw_input_to_cell_weights; |
1056 | TF_LITE_ENSURE_OK(context, |
1057 | GetInputSafe(context, node, kBwInputToCellWeightsTensor, |
1058 | &bw_input_to_cell_weights)); |
1059 | const TfLiteTensor* bw_input_to_output_weights; |
1060 | TF_LITE_ENSURE_OK(context, |
1061 | GetInputSafe(context, node, kBwInputToOutputWeightsTensor, |
1062 | &bw_input_to_output_weights)); |
1063 | |
1064 | const TfLiteTensor* bw_recurrent_to_input_weights = |
1065 | GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor); |
1066 | const TfLiteTensor* bw_recurrent_to_forget_weights; |
1067 | TF_LITE_ENSURE_OK( |
1068 | context, GetInputSafe(context, node, kBwRecurrentToForgetWeightsTensor, |
1069 | &bw_recurrent_to_forget_weights)); |
1070 | const TfLiteTensor* bw_recurrent_to_cell_weights; |
1071 | TF_LITE_ENSURE_OK(context, |
1072 | GetInputSafe(context, node, kBwRecurrentToCellWeightsTensor, |
1073 | &bw_recurrent_to_cell_weights)); |
1074 | const TfLiteTensor* bw_recurrent_to_output_weights; |
1075 | TF_LITE_ENSURE_OK( |
1076 | context, GetInputSafe(context, node, kBwRecurrentToOutputWeightsTensor, |
1077 | &bw_recurrent_to_output_weights)); |
1078 | |
1079 | const TfLiteTensor* bw_cell_to_input_weights = |
1080 | GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor); |
1081 | const TfLiteTensor* bw_cell_to_forget_weights = |
1082 | GetOptionalInputTensor(context, node, kBwCellToForgetWeightsTensor); |
1083 | const TfLiteTensor* bw_cell_to_output_weights = |
1084 | GetOptionalInputTensor(context, node, kBwCellToOutputWeightsTensor); |
1085 | |
1086 | const TfLiteTensor* bw_input_gate_bias = |
1087 | GetOptionalInputTensor(context, node, kBwInputGateBiasTensor); |
1088 | const TfLiteTensor* bw_forget_gate_bias; |
1089 | TF_LITE_ENSURE_OK(context, |
1090 | GetInputSafe(context, node, kBwForgetGateBiasTensor, |
1091 | &bw_forget_gate_bias)); |
1092 | const TfLiteTensor* bw_cell_gate_bias; |
1093 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwCellGateBiasTensor, |
1094 | &bw_cell_gate_bias)); |
1095 | const TfLiteTensor* bw_output_gate_bias; |
1096 | TF_LITE_ENSURE_OK(context, |
1097 | GetInputSafe(context, node, kBwOutputGateBiasTensor, |
1098 | &bw_output_gate_bias)); |
1099 | |
1100 | const TfLiteTensor* bw_projection_weights = |
1101 | GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor); |
1102 | const TfLiteTensor* bw_projection_bias = |
1103 | GetOptionalInputTensor(context, node, kBwProjectionBiasTensor); |
1104 | |
1105 | // State tensors. |
1106 | TfLiteTensor* bw_activation_state = |
1107 | GetVariableInput(context, node, kBwInputActivationStateTensor); |
1108 | TFLITE_DCHECK(bw_activation_state != nullptr); |
1109 | TfLiteTensor* bw_cell_state = |
1110 | GetVariableInput(context, node, kBwInputCellStateTensor); |
1111 | TFLITE_DCHECK(bw_cell_state != nullptr); |
1112 | TfLiteTensor* bw_output = params->merge_outputs |
1113 | ? nullptr |
1114 | : GetOutput(context, node, kBwOutputTensor); |
1115 | |
1116 | // Temporary tensors. |
1117 | TfLiteTensor* fw_scratch_buffer; |
1118 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kFwScratchBuffer, |
1119 | &fw_scratch_buffer)); |
1120 | TfLiteTensor* bw_scratch_buffer; |
1121 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kBwScratchBuffer, |
1122 | &bw_scratch_buffer)); |
1123 | |
1124 | // (Optional) auxiliary inputs. |
1125 | const TfLiteTensor* aux_input = |
1126 | GetOptionalInputTensor(context, node, kAuxInputTensor); |
1127 | const TfLiteTensor* fw_aux_input_to_input_weights = |
1128 | GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor); |
1129 | const TfLiteTensor* fw_aux_input_to_forget_weights = |
1130 | GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor); |
1131 | const TfLiteTensor* fw_aux_input_to_cell_weights = |
1132 | GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor); |
1133 | const TfLiteTensor* fw_aux_input_to_output_weights = |
1134 | GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor); |
1135 | const TfLiteTensor* bw_aux_input_to_input_weights = |
1136 | GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor); |
1137 | const TfLiteTensor* bw_aux_input_to_forget_weights = |
1138 | GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor); |
1139 | const TfLiteTensor* bw_aux_input_to_cell_weights = |
1140 | GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor); |
1141 | const TfLiteTensor* bw_aux_input_to_output_weights = |
1142 | GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor); |
1143 | |
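  // Infer the stacking configuration from which optional inputs are wired up:
  // an aux_input tensor carries a preceding layer's output, and aux-input
  // weights indicate cross links (see the case analysis below).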
1144 | const bool has_previous_bw_output = (aux_input != nullptr); |
1145 | const bool use_aux_input = (fw_aux_input_to_forget_weights != nullptr); |
1146 | |
1147 | // Populate a TfLiteLSTMParams struct for the evaluation functions. |
1148 | TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip, |
1149 | params->proj_clip, kTfLiteLSTMFullKernel, |
1150 | params->asymmetric_quantize_inputs}; |
1151 | |
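  // With merged outputs both directions write into the fw output tensor; the
  // backward results start at a column offset equal to the forward n_output
  // (the second dimension of the fw recurrent-to-output weights).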
1152 | const int bw_output_offset = |
1153 | params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0; |
1154 | const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output; |
1155 | |
1156 | const bool time_major = params->time_major; |
1157 | |
1158 | // We want to cover the following cases: |
1159 | // |
1160 | // If not stacking (not connected after other bidi lstms): |
1161 | // both fw & bw will just use `input`; aux_input will be null. |
1162 | // |
1163 | // If stacking with cross_links, TensorFlow equivalent |
1164 | // (tf.contrib.rnn.stack_bidirectional_rnn): |
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note that this case works whether or not it is connected after other
  //   bidi lstms.
1167 | // |
1168 | // If stacking without cross_links, but connected after other bidi lstms, |
1169 | // TensorFlow equivalent (tf.nn.static_bidirectional_rnn): |
1170 | // fw will use `input`, bw will use aux_input, and the `real aux_input` |
1171 | // will be null. |
1172 | |
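  // `non_stacking_mode` corresponds to the third case above (stacking without
  // cross links): aux_input is present but there are no aux-input weights, so
  // the bw cell consumes aux_input as its sequence input and no auxiliary
  // input is forwarded to the eval functions.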
1173 | const bool non_stacking_mode = !use_aux_input && has_previous_bw_output; |
1174 | const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input; |
1175 | const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input; |
1176 | |
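  // Dispatch on the weight type: float32 weights take the float path, while
  // uint8/int8 weights take the hybrid path (quantized weights with float
  // activations, quantizing inputs and states on the fly).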
1177 | switch (fw_input_to_output_weights->type) { |
1178 | case kTfLiteFloat32: { |
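      // Forward pass: walk the sequence in increasing time order, writing
      // results into fw_output at offset 0.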
1179 | TfLiteStatus fw_pass_status = lstm_eval::EvalFloat( |
1180 | input, fw_input_to_input_weights, fw_input_to_forget_weights, |
1181 | fw_input_to_cell_weights, fw_input_to_output_weights, |
1182 | fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, |
1183 | fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, |
1184 | fw_cell_to_input_weights, fw_cell_to_forget_weights, |
1185 | fw_cell_to_output_weights, |
1186 | /*input_layer_norm_coefficients=*/nullptr, |
1187 | /*forget_layer_norm_coefficients=*/nullptr, |
1188 | /*cell_layer_norm_coefficients=*/nullptr, |
1189 | /*output_layer_norm_coefficients=*/nullptr, real_aux_input, |
1190 | fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights, |
1191 | fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, |
1192 | fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias, |
1193 | fw_output_gate_bias, fw_projection_weights, fw_projection_bias, |
1194 | &lstm_params, |
1195 | /*forward_sequence=*/true, time_major, /*output_offset=*/0, |
1196 | fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output, |
1197 | CpuBackendContext::GetFromContext(context)); |
1198 | TF_LITE_ENSURE_OK(context, fw_pass_status); |
1199 | |
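      // Backward pass: the same cell evaluation in reverse time order; with
      // merged outputs it writes into fw_output at bw_output_offset.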
1200 | TfLiteStatus bw_pass_status = lstm_eval::EvalFloat( |
1201 | bw_input, bw_input_to_input_weights, bw_input_to_forget_weights, |
1202 | bw_input_to_cell_weights, bw_input_to_output_weights, |
1203 | bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, |
1204 | bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, |
1205 | bw_cell_to_input_weights, bw_cell_to_forget_weights, |
1206 | bw_cell_to_output_weights, |
1207 | /*input_layer_norm_coefficients=*/nullptr, |
1208 | /*forget_layer_norm_coefficients=*/nullptr, |
1209 | /*cell_layer_norm_coefficients=*/nullptr, |
1210 | /*output_layer_norm_coefficients=*/nullptr, real_aux_input, |
1211 | bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights, |
1212 | bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights, |
1213 | bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias, |
1214 | bw_output_gate_bias, bw_projection_weights, bw_projection_bias, |
1215 | &lstm_params, |
1216 | /*forward_sequence=*/false, time_major, bw_output_offset, |
1217 | bw_scratch_buffer, bw_activation_state, bw_cell_state, |
1218 | actual_bw_output, CpuBackendContext::GetFromContext(context)); |
1219 | TF_LITE_ENSURE_OK(context, bw_pass_status); |
1220 | return kTfLiteOk; |
1221 | } |
1222 | case kTfLiteUInt8: |
1223 | case kTfLiteInt8: { |
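      // The hybrid path quantizes float inputs and states on the fly, so it
      // needs temporaries for the quantized copies, the associated scaling
      // factors and zero points, recovered (dequantized) cell weights, and
      // cached row sums for asymmetric-quantization corrections.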
1224 | TfLiteTensor* input_quantized; |
1225 | TF_LITE_ENSURE_OK( |
1226 | context, |
1227 | GetTemporarySafe(context, node, kInputQuantized, &input_quantized)); |
1228 | TfLiteTensor* fw_activation_state_quantized; |
1229 | TF_LITE_ENSURE_OK( |
1230 | context, GetTemporarySafe(context, node, kFwActivationStateQuantized, |
1231 | &fw_activation_state_quantized)); |
1232 | TfLiteTensor* bw_activation_state_quantized; |
1233 | TF_LITE_ENSURE_OK( |
1234 | context, GetTemporarySafe(context, node, kBwActivationStateQuantized, |
1235 | &bw_activation_state_quantized)); |
1236 | TfLiteTensor* fw_cell_state_quantized; |
1237 | TF_LITE_ENSURE_OK(context, |
1238 | GetTemporarySafe(context, node, kFwCellStateQuantized, |
1239 | &fw_cell_state_quantized)); |
1240 | TfLiteTensor* bw_cell_state_quantized; |
1241 | TF_LITE_ENSURE_OK(context, |
1242 | GetTemporarySafe(context, node, kBwCellStateQuantized, |
1243 | &bw_cell_state_quantized)); |
1244 | TfLiteTensor* prod_scaling_factors; |
1245 | TF_LITE_ENSURE_OK(context, |
1246 | GetTemporarySafe(context, node, kProductScalingFactors, |
1247 | &prod_scaling_factors)); |
1248 | TfLiteTensor* recovered_cell_weights; |
1249 | TF_LITE_ENSURE_OK(context, |
1250 | GetTemporarySafe(context, node, kRecoveredCellWeights, |
1251 | &recovered_cell_weights)); |
1252 | TfLiteTensor* aux_input_quantized = |
1253 | use_aux_input ? GetTemporary(context, node, kAuxInputQuantized) |
1254 | : nullptr; |
1255 | TfLiteTensor* accum_scratch; |
1256 | TF_LITE_ENSURE_OK( |
1257 | context, |
1258 | GetTemporarySafe(context, node, kAccumScratchBuffer, &accum_scratch)); |
1259 | TfLiteTensor* fw_row_sums; |
1260 | TF_LITE_ENSURE_OK( |
1261 | context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums)); |
1262 | TfLiteTensor* bw_row_sums; |
1263 | TF_LITE_ENSURE_OK( |
1264 | context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums)); |
1265 | const int fw_row_sums_size = fw_row_sums->dims->data[0]; |
1266 | const int bw_row_sums_size = bw_row_sums->dims->data[0]; |
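      // The *_ledger arguments below only apply to sparse weight formats,
      // which this op does not use, so they are all null.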
1267 | TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid( |
1268 | input, fw_input_to_input_weights, |
1269 | /*input_to_input_weights_ledger*/ nullptr, fw_input_to_forget_weights, |
1270 | /*input_to_forget_weights_ledger*/ nullptr, fw_input_to_cell_weights, |
1271 | /*input_to_cell_weights_ledger*/ nullptr, fw_input_to_output_weights, |
1272 | /*input_to_output_weights_ledger*/ nullptr, |
1273 | fw_recurrent_to_input_weights, |
1274 | /*recurrent_to_input_weights_ledger*/ nullptr, |
1275 | fw_recurrent_to_forget_weights, |
1276 | /*recurrent_to_forget_weights_ledger*/ nullptr, |
1277 | fw_recurrent_to_cell_weights, |
1278 | /*recurrent_to_cell_weights_ledger*/ nullptr, |
1279 | fw_recurrent_to_output_weights, |
1280 | /*recurrent_to_output_weights_ledger*/ nullptr, |
1281 | fw_cell_to_input_weights, fw_cell_to_forget_weights, |
1282 | fw_cell_to_output_weights, |
1283 | /*input_layer_norm_coefficients=*/nullptr, |
1284 | /*forget_layer_norm_coefficients=*/nullptr, |
1285 | /*cell_layer_norm_coefficients=*/nullptr, |
1286 | /*output_layer_norm_coefficients=*/nullptr, real_aux_input, |
1287 | fw_aux_input_to_input_weights, fw_aux_input_to_forget_weights, |
1288 | fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, |
1289 | fw_input_gate_bias, fw_forget_gate_bias, fw_cell_gate_bias, |
1290 | fw_output_gate_bias, fw_projection_weights, |
1291 | /*projection_weights_ledger*/ nullptr, fw_projection_bias, |
1292 | &lstm_params, |
1293 | /*forward_sequence=*/true, time_major, /*output_offset=*/0, |
1294 | fw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors), |
1295 | GetTemporary(context, node, kAuxInputScalingFactors), |
1296 | GetTemporary(context, node, kOutputStateScalingFactors), |
1297 | prod_scaling_factors, recovered_cell_weights, input_quantized, |
1298 | aux_input_quantized, fw_activation_state_quantized, |
1299 | fw_cell_state_quantized, fw_activation_state, fw_cell_state, |
1300 | accum_scratch, fw_output, |
1301 | GetTemporary(context, node, kInputZeroPoints), |
1302 | GetTemporary(context, node, kAuxInputZeroPoints), |
1303 | GetTemporary(context, node, kOutputStateZeroPoints), fw_row_sums, |
1304 | fw_row_sums_size, &op_data->compute_fw_row_sums, |
1305 | CpuBackendContext::GetFromContext(context)); |
1306 | TF_LITE_ENSURE_OK(context, fw_pass_status); |
1307 | |
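      // Backward hybrid pass: shares the input/aux-input/output-state
      // quantization temporaries with the forward pass but uses its own
      // state, scratch, and row-sum tensors.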
1308 | TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid( |
1309 | bw_input, bw_input_to_input_weights, |
1310 | /*input_to_input_weights_ledger*/ nullptr, bw_input_to_forget_weights, |
1311 | /*input_to_forget_weights_ledger*/ nullptr, bw_input_to_cell_weights, |
1312 | /*input_to_cell_weights_ledger*/ nullptr, bw_input_to_output_weights, |
1313 | /*input_to_output_weights_ledger*/ nullptr, |
1314 | bw_recurrent_to_input_weights, |
1315 | /*recurrent_to_input_weights_ledger*/ nullptr, |
1316 | bw_recurrent_to_forget_weights, |
1317 | /*recurrent_to_forget_weights_ledger*/ nullptr, |
1318 | bw_recurrent_to_cell_weights, |
1319 | /*recurrent_to_cell_weights_ledger*/ nullptr, |
1320 | bw_recurrent_to_output_weights, |
1321 | /*recurrent_to_output_weights_ledger*/ nullptr, |
1322 | bw_cell_to_input_weights, bw_cell_to_forget_weights, |
1323 | bw_cell_to_output_weights, |
1324 | /*input_layer_norm_coefficients=*/nullptr, |
1325 | /*forget_layer_norm_coefficients=*/nullptr, |
1326 | /*cell_layer_norm_coefficients=*/nullptr, |
1327 | /*output_layer_norm_coefficients=*/nullptr, real_aux_input, |
1328 | bw_aux_input_to_input_weights, bw_aux_input_to_forget_weights, |
1329 | bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights, |
1330 | bw_input_gate_bias, bw_forget_gate_bias, bw_cell_gate_bias, |
1331 | bw_output_gate_bias, bw_projection_weights, |
1332 | /*projection_weights_ledger*/ nullptr, bw_projection_bias, |
1333 | &lstm_params, |
1334 | /*forward_sequence=*/false, time_major, bw_output_offset, |
1335 | bw_scratch_buffer, GetTemporary(context, node, kInputScalingFactors), |
1336 | GetTemporary(context, node, kAuxInputScalingFactors), |
1337 | GetTemporary(context, node, kOutputStateScalingFactors), |
1338 | prod_scaling_factors, recovered_cell_weights, input_quantized, |
1339 | aux_input_quantized, bw_activation_state_quantized, |
1340 | bw_cell_state_quantized, bw_activation_state, bw_cell_state, |
1341 | accum_scratch, actual_bw_output, |
1342 | GetTemporary(context, node, kInputZeroPoints), |
1343 | GetTemporary(context, node, kAuxInputZeroPoints), |
1344 | GetTemporary(context, node, kOutputStateZeroPoints), bw_row_sums, |
1345 | bw_row_sums_size, &op_data->compute_bw_row_sums, |
1346 | CpuBackendContext::GetFromContext(context)); |
1347 | TF_LITE_ENSURE_OK(context, bw_pass_status); |
1348 | return kTfLiteOk; |
1349 | } |
1350 | default: |
      TF_LITE_KERNEL_LOG(context, "Type %s is not currently supported.",
1352 | TfLiteTypeGetName(fw_input_to_output_weights->type)); |
1353 | return kTfLiteError; |
1354 | } |
1355 | } |
1356 | |
1357 | } // namespace bidirectional_sequence_lstm |
1358 | |
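// Exposes the op to TFLite's builtin op resolver; the registration object is
// a static singleton, so repeated calls return the same pointer. A client
// resolver would typically wire it up along these lines (sketch):
//
//   resolver.AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
//                       Register_BIDIRECTIONAL_SEQUENCE_LSTM());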
1359 | TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() { |
1360 | static TfLiteRegistration r = { |
1361 | bidirectional_sequence_lstm::Init, bidirectional_sequence_lstm::Free, |
1362 | bidirectional_sequence_lstm::Prepare, bidirectional_sequence_lstm::Eval}; |
1363 | return &r; |
1364 | } |
1365 | |
1366 | } // namespace builtin |
1367 | } // namespace ops |
1368 | } // namespace tflite |
1369 | |