/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace bidirectional_sequence_rnn {

namespace {

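// Per-node state. `scratch_tensor_index` is the index of the first temporary
// tensor reserved in Init(); the row-sum flags tell the hybrid path to
// recompute its cached weight row sums on the next invocation.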
struct OpData {
  int scratch_tensor_index;
  bool fw_compute_row_sums = false;
  bool bw_compute_row_sums = false;
};

}  // namespace

// LINT.IfChange

constexpr int kInputTensor = 0;
// Forward and backward cell tensors.
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
constexpr int kFwHiddenStateTensor = 4;
constexpr int kBwWeightsTensor = 5;
constexpr int kBwRecurrentWeightsTensor = 6;
constexpr int kBwBiasTensor = 7;
constexpr int kBwHiddenStateTensor = 8;
// Used as auxiliary input and weights when stacking for
// tf.contrib.rnn.stack_bidirectional_rnn case (with cross links); Used as
// input to the backward cell when stacking for
// tf.nn.static_bidirectional_rnn case (without cross links).
constexpr int kAuxInputTensor = 9;       // Optional.
constexpr int kFwAuxWeightsTensor = 10;  // Optional.
constexpr int kBwAuxWeightsTensor = 11;  // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
constexpr int kBwOutputTensor = 1;  // Only if merge_outputs is false.

// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc)

// Temporary tensors.
enum TemporaryTensor {
  kInputQuantized = 0,
  kFwHiddenStateQuantized = 1,
  kBwHiddenStateQuantized = 2,
  kScalingFactors = 3,
  kAccumScratch = 4,
  kZeroPoints = 5,
  kFwRowSums = 6,
  kBwRowSums = 7,
  kAuxInputQuantized = 8,
  kNumTemporaryTensors = 9
};

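// Init() reserves kNumTemporaryTensors consecutive tensor slots for the
// hybrid path and records the base index in OpData; Free() releases the
// per-node state allocated here.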
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, kNumTemporaryTensors,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

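// Prepare() validates the tensor shapes against each other and the op
// parameters, sets up the temporary tensors needed for hybrid execution, and
// resizes the output tensor(s) according to time_major and merge_outputs.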
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
  TF_LITE_ENSURE_EQ(context, node->outputs->size,
                    params->merge_outputs ? 1 : 2);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* fw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwHiddenStateTensor,
                                          &fw_hidden_state));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));
  const TfLiteTensor* bw_hidden_state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwHiddenStateTensor,
                                          &bw_hidden_state));

  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  const bool aux_inputs_weights_or_none =
      ((fw_aux_input_weights != nullptr) &&
       (bw_aux_input_weights != nullptr)) ||
      ((fw_aux_input_weights == nullptr) && (bw_aux_input_weights == nullptr));
  TF_LITE_ENSURE(context, aux_inputs_weights_or_none);
  const bool has_aux_input = (fw_aux_input_weights != nullptr);

  // Check that the tensor shape parameters are consistent with each other and
  // with the input configuration.
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);

  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int fw_num_units = fw_input_weights->dims->data[0];
  const int bw_num_units = bw_input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    fw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    bw_input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, fw_input_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_input_weights->dims->data[0],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, fw_recurrent_weights->dims->data[0],
                    fw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, bw_recurrent_weights->dims->data[1],
                    bw_bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
  TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);

  if (has_aux_input) {
    // Check that aux_input has the same dimensions (except last) as the input.
    TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
    // Check that aux_input_weights has the same dimensions (except last) as
    // the input_weights.
    TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
    TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      fw_aux_input_weights->dims->data[1]);
    TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
                      bw_aux_input_weights->dims->data[1]);
  }

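  // Hybrid execution (float activations with quantized weights) needs
  // temporaries to hold the quantized input and hidden states, per-batch
  // scaling factors and zero points, an accumulator scratch buffer, and
  // cached weight row sums.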
  if (IsHybridOp(input, fw_input_weights)) {
    OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->fw_compute_row_sums = true;
    op_data->bw_compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    if (has_aux_input) {
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
    } else {
      // No need to create a temporary tensor for the non-existent aux_input.
      node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
    }

    node->temporaries->data[kInputQuantized] =
        op_data->scratch_tensor_index + kInputQuantized;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized,
                                                &input_quantized));
    input_quantized->type = fw_input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, input_quantized,
                                              input_quantized_size));
    }

    node->temporaries->data[kFwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kFwHiddenStateQuantized;
    TfLiteTensor* fw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kFwHiddenStateQuantized,
                                       &fw_hidden_state_quantized));
    fw_hidden_state_quantized->type = fw_input_weights->type;
    fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
                             fw_hidden_state->dims)) {
      TfLiteIntArray* fw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(fw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, fw_hidden_state_quantized,
                                         fw_hidden_state_quantized_size));
    }

    node->temporaries->data[kBwHiddenStateQuantized] =
        op_data->scratch_tensor_index + kBwHiddenStateQuantized;
    TfLiteTensor* bw_hidden_state_quantized;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, kBwHiddenStateQuantized,
                                       &bw_hidden_state_quantized));
    bw_hidden_state_quantized->type = fw_input_weights->type;
    bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
                             bw_hidden_state->dims)) {
      TfLiteIntArray* bw_hidden_state_quantized_size =
          TfLiteIntArrayCopy(bw_hidden_state->dims);
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, bw_hidden_state_quantized,
                                         bw_hidden_state_quantized_size));
    }

    // Allocate temporary tensors to store scaling factors of quantization.
    node->temporaries->data[kScalingFactors] =
        op_data->scratch_tensor_index + kScalingFactors;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScalingFactors,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, scaling_factors,
                                              scaling_factors_size));
    }
    node->temporaries->data[kAccumScratch] =
        op_data->scratch_tensor_index + kAccumScratch;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {std::max(fw_num_units, bw_num_units),
                                 batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[kZeroPoints] =
        op_data->scratch_tensor_index + kZeroPoints;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kZeroPoints, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
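    // One row-sum vector is cached per weight matrix (input, recurrent, and
    // aux input when present), hence 2 or 3 rows. The buffers are persistent
    // so they can be reused across invocations.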
    const int num_row_sums = has_aux_input ? 3 : 2;
    node->temporaries->data[kFwRowSums] =
        op_data->scratch_tensor_index + kFwRowSums;
    TfLiteTensor* fw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kFwRowSums, &fw_row_sums));
    fw_row_sums->type = kTfLiteInt32;
    fw_row_sums->name = "Lstm_fw_row_sums";
    fw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int fw_row_sums_dims[2] = {num_row_sums, fw_num_units};
    if (!TfLiteIntArrayEqualsArray(fw_row_sums->dims, 2, fw_row_sums_dims)) {
      TfLiteIntArray* fw_row_sums_size = TfLiteIntArrayCreate(2);
      fw_row_sums_size->data[0] = fw_row_sums_dims[0];
      fw_row_sums_size->data[1] = fw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_row_sums,
                                                       fw_row_sums_size));
    }
    node->temporaries->data[kBwRowSums] =
        op_data->scratch_tensor_index + kBwRowSums;
    TfLiteTensor* bw_row_sums;
    TF_LITE_ENSURE_OK(
        context,
        GetTemporarySafe(context, node, /*index=*/kBwRowSums, &bw_row_sums));
    bw_row_sums->type = kTfLiteInt32;
    bw_row_sums->name = "Lstm_bw_row_sums";
    bw_row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int bw_row_sums_dims[2] = {num_row_sums, bw_num_units};
    if (!TfLiteIntArrayEqualsArray(bw_row_sums->dims, 2, bw_row_sums_dims)) {
      TfLiteIntArray* bw_row_sums_size = TfLiteIntArrayCreate(2);
      bw_row_sums_size->data[0] = bw_row_sums_dims[0];
      bw_row_sums_size->data[1] = bw_row_sums_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_row_sums,
                                                       bw_row_sums_size));
    }
    if (has_aux_input) {
      node->temporaries->data[kAuxInputQuantized] =
          op_data->scratch_tensor_index + kAuxInputQuantized;
      TfLiteTensor* aux_input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kAuxInputQuantized,
                                         &aux_input_quantized));
      aux_input_quantized->type = fw_input_weights->type;
      aux_input_quantized->allocation_type = kTfLiteArenaRw;
      if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
        TfLiteIntArray* aux_input_quantized_size =
            TfLiteIntArrayCopy(aux_input->dims);
        TF_LITE_ENSURE_OK(context,
                          context->ResizeTensor(context, aux_input_quantized,
                                                aux_input_quantized_size));
      }
    }
  }

  // Resize outputs.
  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
  fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
  fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
  fw_output_size_array->data[2] =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  TF_LITE_ENSURE_OK(
      context,
      context->ResizeTensor(context, fw_output, fw_output_size_array));
  if (!params->merge_outputs) {
    TfLiteTensor* bw_output;
    TF_LITE_ENSURE_OK(
        context, GetOutputSafe(context, node, kBwOutputTensor, &bw_output));
    TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
    bw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
    bw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
    bw_output_size_array->data[2] = bw_num_units;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
                                                     bw_output_size_array));
  }

  return kTfLiteOk;
}

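// Float evaluation: runs the forward cell over the sequence and the backward
// cell over the reversed sequence (the backward cell reads `bw_input`). With
// merge_outputs the backward results are written into fw_output at a column
// offset of fw_num_units; otherwise they go to the separate bw_output tensor.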
TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* bw_input,
                       const TfLiteTensor* fw_input_weights,
                       const TfLiteTensor* fw_recurrent_weights,
                       const TfLiteTensor* fw_bias,
                       const TfLiteTensor* bw_input_weights,
                       const TfLiteTensor* bw_recurrent_weights,
                       const TfLiteTensor* bw_bias,
                       const TfLiteTensor* aux_input,
                       const TfLiteTensor* fw_aux_input_weights,
                       const TfLiteTensor* bw_aux_input_weights,
                       const TfLiteBidirectionalSequenceRNNParams* params,
                       TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
                       TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const float* fw_input_weights_ptr = GetTensorData<float>(fw_input_weights);
  const float* fw_recurrent_weights_ptr =
      GetTensorData<float>(fw_recurrent_weights);

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const float* bw_input_weights_ptr = GetTensorData<float>(bw_input_weights);
  const float* bw_recurrent_weights_ptr =
      GetTensorData<float>(bw_recurrent_weights);

  const float* fw_aux_input_weights_ptr =
      (fw_aux_input_weights != nullptr)
          ? GetTensorData<float>(fw_aux_input_weights)
          : nullptr;
  const float* bw_aux_input_weights_ptr =
      (bw_aux_input_weights != nullptr)
          ? GetTensorData<float>(bw_aux_input_weights)
          : nullptr;

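  // Per-timestep stride in the output buffer(s). When the outputs are merged,
  // both cells write into the same rows of width fw_num_units + bw_num_units.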
  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
          fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
          input_size, aux_input_size, fw_num_units, batch_size, fw_output_step,
          params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          GetTensorData<float>(bw_input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs
               ? GetTensorData<float>(fw_output) + fw_num_units
               : GetTensorData<float>(bw_output)) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
          bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
          input_size, aux_input_size, bw_num_units, batch_size, bw_output_step,
          params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
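    // Batch-major input: each batch entry is a contiguous
    // [max_time, input_size] block, so the cells are run per batch entry with
    // batch_size == 1.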
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
            fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
            input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
            fw_output_step, params->activation, fw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) +
                      b * aux_input_size * max_time + s * aux_input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
            bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
            input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
            bw_output_step, params->activation, bw_hidden_state_ptr_batch,
            output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

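// Hybrid evaluation: same control flow as EvalFloat, but the weights are
// int8/uint8. The float input and hidden states are quantized on the fly into
// the temporary tensors, and asymmetric quantization (when enabled) uses the
// zero points and cached row sums.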
TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* bw_input,
    const TfLiteTensor* fw_input_weights,
    const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
    const TfLiteTensor* bw_input_weights,
    const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
    const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
    const TfLiteTensor* aux_bw_input_weights,
    const TfLiteBidirectionalSequenceRNNParams* params,
    TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
    TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
    TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
    TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
    TfLiteTensor* bw_output, TfLiteTensor* zero_points,
    TfLiteTensor* accum_scratch, TfLiteTensor* fw_row_sums,
    TfLiteTensor* bw_row_sums, bool* fw_compute_row_sums,
    bool* bw_compute_row_sums) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int input_size = input->dims->data[2];
  const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;

  const int fw_num_units = fw_input_weights->dims->data[0];
  const float* fw_bias_ptr = GetTensorData<float>(fw_bias);
  const int8_t* fw_input_weights_ptr = GetTensorData<int8_t>(fw_input_weights);
  float fw_input_weights_scale = fw_input_weights->params.scale;
  const int8_t* fw_recurrent_weights_ptr =
      GetTensorData<int8_t>(fw_recurrent_weights);
  float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;

  const int bw_num_units = bw_input_weights->dims->data[0];
  const float* bw_bias_ptr = GetTensorData<float>(bw_bias);
  const int8_t* bw_input_weights_ptr = GetTensorData<int8_t>(bw_input_weights);
  float bw_input_weights_scale = bw_input_weights->params.scale;
  const int8_t* bw_recurrent_weights_ptr =
      GetTensorData<int8_t>(bw_recurrent_weights);
  float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;

  // Set the auxiliary pointers and scales if needed.
  const int8_t* aux_fw_input_weights_ptr = nullptr;
  float aux_fw_input_weights_scale = 0.0f;
  const int8_t* aux_bw_input_weights_ptr = nullptr;
  float aux_bw_input_weights_scale = 0.0f;
  int8_t* aux_quantized_input_ptr = nullptr;
  if (aux_input_size > 0) {
    aux_fw_input_weights_ptr = GetTensorData<int8_t>(aux_fw_input_weights);
    aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
    aux_bw_input_weights_ptr = GetTensorData<int8_t>(aux_bw_input_weights);
    aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
    aux_quantized_input_ptr = GetTensorData<int8_t>(aux_input_quantized);
  }

  // Initialize temporary storage for quantized values.
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_quantized);
  int8_t* fw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(fw_hidden_state_quantized);
  int8_t* bw_quantized_hidden_state_ptr =
      GetTensorData<int8_t>(bw_hidden_state_quantized);
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
  int32_t* zero_points_ptr = nullptr;
  int32_t* fw_row_sums_ptr = nullptr;
  int32_t* bw_row_sums_ptr = nullptr;
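  // Zero points and row sums are only needed when the inputs are quantized
  // asymmetrically.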
  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    fw_row_sums_ptr = GetTensorData<int32_t>(fw_row_sums);
    bw_row_sums_ptr = GetTensorData<int32_t>(bw_row_sums);
  }
  const int fw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
  const int bw_output_step =
      params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;

  if (time_major) {
    // Forward cell.
    float* fw_hidden_state_ptr_batch = GetTensorData<float>(fw_hidden_state);
    for (int s = 0; s < max_time; s++) {
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          GetTensorData<float>(fw_output) + s * fw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
          aux_input_ptr_batch, aux_fw_input_weights_ptr,
          aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
          fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
          fw_num_units, batch_size, fw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          fw_quantized_hidden_state_ptr, scaling_factors_ptr,
          fw_hidden_state_ptr_batch, output_ptr_batch,
          params->asymmetric_quantize_inputs, zero_points_ptr,
          accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
    }
    // Backward cell.
    float* bw_hidden_state_ptr_batch = GetTensorData<float>(bw_hidden_state);
    for (int s = max_time - 1; s >= 0; s--) {
      const float* input_ptr_batch =
          GetTensorData<float>(bw_input) + s * input_size * batch_size;
      const float* aux_input_ptr_batch =
          (aux_input != nullptr)
              ? GetTensorData<float>(aux_input) + s * input_size * batch_size
              : nullptr;
      float* output_ptr_batch =
          (params->merge_outputs
               ? GetTensorData<float>(fw_output) + fw_num_units
               : GetTensorData<float>(bw_output)) +
          s * bw_output_step * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
          aux_input_ptr_batch, aux_bw_input_weights_ptr,
          aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
          bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
          bw_num_units, batch_size, bw_output_step, params->activation,
          quantized_input_ptr, aux_quantized_input_ptr,
          bw_quantized_hidden_state_ptr, scaling_factors_ptr,
          bw_hidden_state_ptr_batch, output_ptr_batch,
          params->asymmetric_quantize_inputs, zero_points_ptr,
          accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
    }
  } else {
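    // Batch-major input: run the cells per batch entry, as in EvalFloat.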
    for (int b = 0; b < batch_size; b++) {
      // Forward cell.
      float* fw_hidden_state_ptr_batch =
          GetTensorData<float>(fw_hidden_state) + b * fw_num_units;
      float* fw_output_offset =
          GetTensorData<float>(fw_output) + b * fw_output_step * max_time;
      for (int s = 0; s < max_time; s++) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) + b * input_size * max_time +
                      s * input_size
                : nullptr;
        float* output_ptr_batch = fw_output_offset + s * fw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
            aux_input_ptr_batch, aux_fw_input_weights_ptr,
            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
            fw_recurrent_weights_scale, fw_bias_ptr, input_size,
            aux_input_size, fw_num_units, /*batch_size=*/1, fw_output_step,
            params->activation, quantized_input_ptr, aux_quantized_input_ptr,
            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
            fw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, fw_row_sums_ptr, fw_compute_row_sums);
      }
      // Backward cell.
      float* bw_hidden_state_ptr_batch =
          GetTensorData<float>(bw_hidden_state) + b * bw_num_units;
      float* bw_output_offset =
          params->merge_outputs
              ? GetTensorData<float>(fw_output) +
                    b * bw_output_step * max_time + fw_num_units
              : GetTensorData<float>(bw_output) + b * bw_output_step * max_time;
      for (int s = max_time - 1; s >= 0; s--) {
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        const float* aux_input_ptr_batch =
            (aux_input != nullptr)
                ? GetTensorData<float>(aux_input) + b * input_size * max_time +
                      s * input_size
                : nullptr;
        float* output_ptr_batch = bw_output_offset + s * bw_output_step;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
            aux_input_ptr_batch, aux_bw_input_weights_ptr,
            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
            bw_recurrent_weights_scale, bw_bias_ptr, input_size,
            aux_input_size, bw_num_units, /*batch_size=*/1, bw_output_step,
            params->activation, quantized_input_ptr, aux_quantized_input_ptr,
            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
            bw_hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, bw_row_sums_ptr, bw_compute_row_sums);
      }
    }
  }
  return kTfLiteOk;
}

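// Eval() gathers the node's tensors, decides which tensor feeds the backward
// cell (see the stacking discussion below), and dispatches to the float or
// hybrid kernel based on the weight type.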
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
      node->builtin_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kFwWeightsTensor,
                                          &fw_input_weights));
  const TfLiteTensor* fw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwRecurrentWeightsTensor,
                                 &fw_recurrent_weights));
  const TfLiteTensor* fw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFwBiasTensor, &fw_bias));
  const TfLiteTensor* bw_input_weights;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBwWeightsTensor,
                                          &bw_input_weights));
  const TfLiteTensor* bw_recurrent_weights;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwRecurrentWeightsTensor,
                                 &bw_recurrent_weights));
  const TfLiteTensor* bw_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kBwBiasTensor, &bw_bias));

  // Get auxiliary inputs.
  const TfLiteTensor* aux_input =
      GetOptionalInputTensor(context, node, kAuxInputTensor);
  const TfLiteTensor* fw_aux_input_weights =
      GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
  const TfLiteTensor* bw_aux_input_weights =
      GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);

  TfLiteTensor* fw_hidden_state =
      GetVariableInput(context, node, kFwHiddenStateTensor);
  TFLITE_DCHECK(fw_hidden_state != nullptr);
  TfLiteTensor* bw_hidden_state =
      GetVariableInput(context, node, kBwHiddenStateTensor);
  TFLITE_DCHECK(bw_hidden_state != nullptr);

  TfLiteTensor* fw_output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kFwOutputTensor, &fw_output));
  TfLiteTensor* bw_output = params->merge_outputs
                                ? nullptr
                                : GetOutput(context, node, kBwOutputTensor);

  const bool has_previous_bw_output = (aux_input != nullptr);
  const bool use_aux_input = (fw_aux_input_weights != nullptr);

  // We want to cover the following cases:
  //
  // If not stacking (not connected after other bidi lstms):
  //   both fw & bw will just use `input`; aux_input will be null.
  //
  // If stacking with cross_links, TensorFlow equivalent
  // (tf.contrib.rnn.stack_bidirectional_rnn):
  //   both fw & bw will use `input`, but aux_input will be non-null.
  //   Note that this case works whether or not it is connected after other
  //   bidi lstms.
  //
  // If stacking without cross_links, but connected after other bidi lstms,
  // TensorFlow equivalent (tf.nn.static_bidirectional_rnn):
  //   fw will use `input`, bw will use aux_input, and the `real aux_input`
  //   will be null.

  const bool non_stacking_mode = !use_aux_input && has_previous_bw_output;
  const TfLiteTensor* bw_input = non_stacking_mode ? aux_input : input;
  const TfLiteTensor* real_aux_input = non_stacking_mode ? nullptr : aux_input;

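  // Dispatch on the weight type: float weights use the reference float
  // kernel; int8/uint8 weights use the hybrid kernel with the temporaries
  // set up in Prepare().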
  switch (fw_input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, bw_input, fw_input_weights, fw_recurrent_weights,
                       fw_bias, bw_input_weights, bw_recurrent_weights,
                       bw_bias, real_aux_input, fw_aux_input_weights,
                       bw_aux_input_weights, params, fw_hidden_state,
                       fw_output, bw_hidden_state, bw_output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kInputQuantized, &input_quantized));
      TfLiteTensor* fw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kFwHiddenStateQuantized,
                                         &fw_hidden_state_quantized));
      TfLiteTensor* bw_hidden_state_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, kBwHiddenStateQuantized,
                                         &bw_hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(
          context,
          GetTemporarySafe(context, node, kScalingFactors, &scaling_factors));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kZeroPoints, &zero_points));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch,
                                                  &accum_scratch));
      TfLiteTensor* fw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kFwRowSums, &fw_row_sums));
      TfLiteTensor* bw_row_sums;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, kBwRowSums, &bw_row_sums));
      TfLiteTensor* aux_input_quantized =
          use_aux_input ? GetTemporary(context, node, kAuxInputQuantized)
                        : nullptr;
      auto* op_data = reinterpret_cast<OpData*>(node->user_data);
      return EvalHybrid(
          input, bw_input, fw_input_weights, fw_recurrent_weights, fw_bias,
          bw_input_weights, bw_recurrent_weights, bw_bias, real_aux_input,
          fw_aux_input_weights, bw_aux_input_weights, params, scaling_factors,
          input_quantized, aux_input_quantized, fw_hidden_state_quantized,
          fw_hidden_state, fw_output, bw_hidden_state_quantized,
          bw_hidden_state, bw_output, zero_points, accum_scratch, fw_row_sums,
          bw_row_sums, &op_data->fw_compute_row_sums,
          &op_data->bw_compute_row_sums);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type not currently supported.");
      return kTfLiteError;
  }
}

}  // namespace bidirectional_sequence_rnn

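// Registration entry point for the builtin BIDIRECTIONAL_SEQUENCE_RNN op.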
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
      bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite