/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {

namespace {

struct OpData {
  int scratch_tensor_index;
  bool fw_compute_row_sums = false;
  bool bw_compute_row_sums = false;
};

}  // namespace

// LINT.IfChange

constexpr int kInputTensor = 0;
// Forward and backward cell tensors.
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
constexpr int kFwHiddenStateTensor = 4;
constexpr int kBwWeightsTensor = 5;
constexpr int kBwRecurrentWeightsTensor = 6;
constexpr int kBwBiasTensor = 7;
constexpr int kBwHiddenStateTensor = 8;
// Used as auxiliary input and weights when stacking for the
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); used as
// input to the backward cell when stacking for the
// tf.nn.static_bidirectional_rnn case (without cross links).
constexpr int kAuxInputTensor = 9;       // Optional.
constexpr int kFwAuxWeightsTensor = 10;  // Optional.
constexpr int kBwAuxWeightsTensor = 11;  // Optional.
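//
// A rough wiring sketch of the two stacking modes (illustrative only):
//   stack_bidirectional_rnn (cross links):   static_bidirectional_rnn:
//     fw cell <- input (+ aux_input)           fw cell <- input
//     bw cell <- input (+ aux_input)           bw cell <- aux_input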
// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Only if merge_outputs is false.

// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc)

// Temporary tensors.
enum TemporaryTensor {
  kInputQuantized = 0,
  kFwHiddenStateQuantized = 1,
  kBwHiddenStateQuantized = 2,
  kScalingFactors = 3,
  kAccumScratch = 4,
  kZeroPoints = 5,
  kFwRowSums = 6,
  kBwRowSums = 7,
  kAuxInputQuantized = 8,
  kNumTemporaryTensors = 9
};

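// The temporaries above are only allocated on the hybrid path (float
// activations with uint8/int8 weights); kAuxInputQuantized is last so that it
// can be dropped when there is no aux input. Init() reserves a contiguous
// block of kNumTemporaryTensors tensor slots and records the index of the
// first one in scratch_tensor_index.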
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* fw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwHiddenStateTensor,
                                          &fw_hidden_state));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
  const TfLiteTensor* bw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwHiddenStateTensor,
                                          &bw_hidden_state));

  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  const bool aux_inputs_weights_or_none =
      ((fw_aux_input_weights != nullptr) &&
       (bw_aux_input_weights != nullptr)) ||
      ((fw_aux_input_weights == nullptr) && (bw_aux_input_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_or_none);
  const bool has_aux_input = (fw_aux_input_weights != nullptr);

  // Check that the tensor parameters are consistent with each other and with
  // the input configuration.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);

  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
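  // The input is laid out as [max_time, batch_size, input_size] when
  // time_major is set, and [batch_size, max_time, input_size] otherwise.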
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int fw_num_units = fw_input_weights->dims->data[0];
  const int bw_num_units = bw_input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    fw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    bw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, fw_input_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_input_weights->dims->data[0],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_weights->dims->data[1],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_weights->dims->data[1],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
  TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
    // Check that aux_input_weights has the same dimensions (except last) as
    // the input_weights.
    TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
    TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      fw_aux_input_weights->dims->data[1]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      bw_aux_input_weights->dims->data[1]);
  }

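  // Hybrid mode: float activations with uint8/int8 weights. The scratch
  // tensors below hold quantized copies of the float inputs.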
  if (IsHybridOp(input, fw_input_weights)) {
    OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->fw_compute_row_sums = true;
    op_data->bw_compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    if (has_aux_input) {
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
    } else {
      // No need to create a temporary tensor for the non-existent aux_input.
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
    }

    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = fw_input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, input_quantized,
                                              input_quantized_size));
    }

    node->temporaries->data[kFwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kFwHiddenStateQuantized;
    TfLiteTensor* fw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kFwHiddenStateQuantized,
                                       &fw_hidden_state_quantized));
    fw_hidden_state_quantized->type = fw_input_weights->type;
    fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
                             fw_hidden_state->dims)) {
      TfLiteIntArray* fw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(fw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_hidden_state_quantized,
                                         fw_hidden_state_quantized_size));
    }

    node->temporaries->data[kBwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kBwHiddenStateQuantized;
    TfLiteTensor* bw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kBwHiddenStateQuantized,
                                       &bw_hidden_state_quantized));
    bw_hidden_state_quantized->type = fw_input_weights->type;
    bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
                             bw_hidden_state->dims)) {
      TfLiteIntArray* bw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(bw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_hidden_state_quantized,
                                         bw_hidden_state_quantized_size));
    }

    // Allocate temporary tensors to store the quantization scaling factors.
    node->temporaries->data[kScalingFactors] =
        op_data->scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScalingFactors,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, scaling_factors,
                                              scaling_factors_size));
    }
    node->temporaries->data[kAccumScratch] =
        op_data->scratch_tensor_index + kAccumScratch;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
                                 batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[kZeroPoints] =
        op_data->scratch_tensor_index + kZeroPoints;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kZeroPoints, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
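    // One cached row-sum vector per weight matrix that multiplies a quantized
    // activation: the input weights, the recurrent weights, and (when
    // present) the aux input weights.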
    const int num_row_sums = has_aux_input ? 3 : 2;
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
    TfLiteTensor* fw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->name = "Lstm_fw_row_sums";
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
      TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2);
      fw_row_sums_size->data[0] = fw_row_sums_dims[0];
      fw_row_sums_size->data[1] = fw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
                                                       fw_row_sums_size));
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
    TfLiteTensor* bw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->name = "Lstm_bw_row_sums";
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
                                                       bw_row_sums_size));
    }
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kAuxInputQuantized,
                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }

  // Resize outputs.
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
  fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
  fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
  fw_output_size_array->data[2] =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  TF_LITE_ENSURE_OK(
      context,
      context->ResizeTensor(context, fw_output, fw_output_size_array));
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output;
    TF_LITE_ENSURE_OK(
        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
    TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
    bw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
    bw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
    bw_output_size_array->data[2] = bw_num_units;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
                                                     bw_output_size_array));
  }

  return kTfLiteOk;
}

TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* bw_input,
                       const TfLiteTensor* fw_input_weights,
                       const TfLiteTensor* fw_recurrent_weights,
                       const TfLiteTensor* fw_bias,
                       const TfLiteTensor* bw_input_weights,
                       const TfLiteTensor* bw_recurrent_weights,
                       const TfLiteTensor* bw_bias,
                       const TfLiteTensor* aux_input,
                       const TfLiteTensor* fw_aux_input_weights,
                       const TfLiteTensor* bw_aux_input_weights,
                       const TfLiteBidirectionalSequenceRNNParams* params,
                       TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
                       TfLiteTensor* bw_hidden_state,
                       TfLiteTensor* bw_output) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
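  // When there is no aux input, aux_input_size is 0 and the aux pointers stay
  // null, so RnnBatchStep skips the corresponding terms.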

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const float* fw_input_weights_ptr = GetTensorData<float>(fw_input_weights);
  const float* fw_recurrent_weights_ptr =
      GetTensorData<float>(fw_recurrent_weights);

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const float* bw_input_weights_ptr = GetTensorData<float>(bw_input_weights);
  const float* bw_recurrent_weights_ptr =
      GetTensorData<float>(bw_recurrent_weights);

  const float* fw_aux_input_weights_ptr =
      (fw_aux_input_weights != nullptr)
          ? GetTensorData<float>(fw_aux_input_weights)
          : nullptr;
  const float* bw_aux_input_weights_ptr =
      (bw_aux_input_weights != nullptr)
          ? GetTensorData<float>(bw_aux_input_weights)
          : nullptr;

  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
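  // With merged outputs, each output row is laid out as
  // [fw_num_units | bw_num_units]: both cells step by the combined width and
  // the backward cell writes at offset fw_num_units.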
  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
          fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
          input_size, aux_input_size, fw_num_units, batch_size, fw_output_step,
          params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          GetTensorData<float>(bw_input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs
               ? GetTensorData<float>(fw_output) + fw_num_units
               : GetTensorData<float>(bw_output)) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
          bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
          input_size, aux_input_size, bw_num_units, batch_size, bw_output_step,
          params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
            fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
            input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
            fw_output_step, params->activation, fw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) +
                    b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(bw_input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
            bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
            input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
            bw_output_step, params->activation, bw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* bw_input,
    const TfLiteTensor* fw_input_weights,
    const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
    const TfLiteTensor* bw_input_weights,
    const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
    const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
    const TfLiteTensor* aux_bw_input_weights,
    const TfLiteBidirectionalSequenceRNNParams* params,
    TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
    TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
    TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
    TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
    TfLiteTensor* bw_output, TfLiteTensor* zero_points,
    TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
    TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
    bool* bw_compute_row_sums) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const int8_t* fw_input_weights_ptr = GetTensorData<int8_t>(fw_input_weights);
  float fw_input_weights_scale = fw_input_weights->params.scale;
  const int8_t* fw_recurrent_weights_ptr =
      GetTensorData<int8_t>(fw_recurrent_weights);
  float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const int8_t* bw_input_weights_ptr = GetTensorData<int8_t>(bw_input_weights);
  float bw_input_weights_scale = bw_input_weights->params.scale;
  const int8_t* bw_recurrent_weights_ptr =
      GetTensorData<int8_t>(bw_recurrent_weights);
  float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;

  // Set the auxiliary pointers and scales if needed.
  const int8_t* aux_fw_input_weights_ptr = nullptr;
  float aux_fw_input_weights_scale = 0.0f;
  const int8_t* aux_bw_input_weights_ptr = nullptr;
  float aux_bw_input_weights_scale = 0.0f;
  int8_t* aux_quantized_input_ptr = nullptr;
  if (aux_input_size > 0) {
    aux_fw_input_weights_ptr = GetTensorData<int8_t>(aux_fw_input_weights);
    aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
    aux_bw_input_weights_ptr = GetTensorData<int8_t>(aux_bw_input_weights);
    aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
    aux_quantized_input_ptr = GetTensorData<int8_t>(aux_input_quantized);
  }

  // Initialize temporary storage for quantized values.
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_quantized);
  int8_t* fw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(fw_hidden_state_quantized);
  int8_t* bw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(bw_hidden_state_quantized);
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
  int32_t* zero_points_ptr = nullptr;
  int32_t* fw_row_sums_ptr = nullptr;
  int32_t* bw_row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    fw_row_sums_ptr = GetTensorData<int32_t>(fw_row_sums);
    bw_row_sums_ptr = GetTensorData<int32_t>(bw_row_sums);
  }
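  // With asymmetric input quantization, per-batch zero points are recorded at
  // quantization time and the cached weight row sums are used inside
  // RnnBatchStep to correct the resulting offset terms.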
  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;

  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
          aux_input_ptr_batch, aux_fw_input_weights_ptr,
          aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
          fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
          fw_num_units, batch_size, fw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          fw_quantized_hidden_state_ptr, scaling_factors_ptr,
          fw_hidden_state_ptr_batch, output_ptr_batch,
          params->asymmetric_quantize_inputs, zero_points_ptr,
          accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          GetTensorData<float>(bw_input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs
               ? GetTensorData<float>(fw_output) + fw_num_units
               : GetTensorData<float>(bw_output)) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
          aux_input_ptr_batch, aux_bw_input_weights_ptr,
          aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
          bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
          bw_num_units, batch_size, bw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          bw_quantized_hidden_state_ptr, scaling_factors_ptr,
          bw_hidden_state_ptr_batch, output_ptr_batch,
          params->asymmetric_quantize_inputs, zero_points_ptr,
          accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
    }
  } else {
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
            aux_input_ptr_batch, aux_fw_input_weights_ptr,
            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
            fw_recurrent_weights_scale, fw_bias_ptr, input_size,
            aux_input_size, fw_num_units, /*batch_size=*/1, fw_output_step,
            params->activation, quantized_input_ptr, aux_quantized_input_ptr,
            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
            fw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) +
                    b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(bw_input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
            aux_input_ptr_batch, aux_bw_input_weights_ptr,
            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
            bw_recurrent_weights_scale, bw_bias_ptr, input_size,
            aux_input_size, bw_num_units, /*batch_size=*/1, bw_output_step,
            params->activation, quantized_input_ptr, aux_quantized_input_ptr,
            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
            bw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));

  // Get auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  TfLiteTensor* fw_hidden_state =
      GetVariableInput(context, node, kFwHiddenStateTensor);
  TFLITE_DCHECK(fw_hidden_state != nullptr);
  TfLiteTensor* bw_hidden_state =
      GetVariableInput(context, node, kBwHiddenStateTensor);
  TFLITE_DCHECK(bw_hidden_state != nullptr);

  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_weights != nullptr);

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null. Note
  //   that in this case it does not matter whether this op is connected
  //   after other bidi lstms.
  //
  // If stacking without cross links, but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;
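  // Summary of the routing above:
  //   use_aux_input | has_previous_bw_output | bw_input  | real_aux_input
  //   --------------+------------------------+-----------+---------------
  //   false         | false                  | input     | nullptr
  //   true          | any                    | input     | aux_input
  //   false         | true                   | aux_input | nullptr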

  switch (fw_input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, bw_input, fw_input_weights, fw_recurrent_weights,
                       fw_bias, bw_input_weights, bw_recurrent_weights, bw_bias,
                       real_aux_input, fw_aux_input_weights,
                       bw_aux_input_weights, params, fw_hidden_state, fw_output,
                       bw_hidden_state, bw_output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
      TfLiteTensor* fw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kFwHiddenStateQuantized,
                                         &fw_hidden_state_quantized));
      TfLiteTensor* bw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kBwHiddenStateQuantized,
                                         &bw_hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kScalingFactors, &scaling_factors));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kZeroPoints, &zero_points));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                  &accum_scratch));
      TfLiteTensor* fw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
      TfLiteTensor* bw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
      auto* op_data = reinterpret_cast<OpData*>(node->user_data);
      return EvalHybrid(
          input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias,
          bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input,
          fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors,
          input_quantized, aux_input_quantized, fw_hidden_state_quantized,
          fw_hidden_state, fw_output, bw_hidden_state_quantized,
          bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums,
          bw_row_sums, &op_data->fw_compute_row_sums,
          &op_data->bw_compute_row_sums);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type not currently supported.");
      return kTfLiteError;
  }
}

}  // namespace bidirectional_sequence_rnn

TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
      bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite
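
// Usage sketch (not part of this kernel, shown for orientation): the op is
// reachable through the builtin resolver, e.g.
//
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);
//
// Models containing BIDIRECTIONAL_SEQUENCE_RNN then dispatch to the
// Prepare/Eval functions defined above.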
852 | |