/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_rnn {

namespace {

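// Per-node data: the index of the first scratch tensor reserved in Init() and
// a flag telling the hybrid path to (re)compute the cached weight row sums
// used for asymmetric input quantization.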
struct OpData {
  int scratch_tensor_index;
  bool compute_row_sums = false;
};

}  // namespace

// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kHiddenStateTensor = 4;

// Output tensor.
constexpr int kOutputTensor = 0;

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
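  // Reserve six scratch tensors for the hybrid (quantized-weight) path:
  // quantized input, quantized hidden state, scaling factors, accumulator
  // scratch, zero points, and row sums.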
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  const TfLiteTensor* hidden_state;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));

  // Check that all tensor parameters are consistent with each other and with
  // the input configuration.
  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int num_units = input_weights->dims->data[0];
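  // input_weights is [num_units, input_size] and recurrent_weights is
  // [num_units, num_units]; both leading dimensions must match the bias size.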
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input_weights->dims->data[0], bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[0],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[1],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, input_weights->type,
                          recurrent_weights->type);
  TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
  output_size_array->data[0] = (time_major) ? max_time : batch_size;
  output_size_array->data[1] = (time_major) ? batch_size : max_time;
  output_size_array->data[2] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

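  // Hybrid mode: float activations with 8-bit quantized weights.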
  const bool is_hybrid = IsHybridOp(input, input_weights);

  // Allocate temporary tensors to store quantized values of input and
  // hidden_state tensors.
  if (is_hybrid) {
    auto* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    node->temporaries = TfLiteIntArrayCreate(6);
    node->temporaries->data[0] = op_data->scratch_tensor_index;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    input_quantized->type = input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
    TfLiteTensor* hidden_state_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &hidden_state_quantized));
    hidden_state_quantized->type = input_weights->type;
    hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
                             hidden_state->dims)) {
      TfLiteIntArray* hidden_state_quantized_size =
          TfLiteIntArrayCopy(hidden_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, hidden_state_quantized,
                                              hidden_state_quantized_size));
    }
    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {num_units, batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
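    // One row of cached sums per weight matrix: input weights and recurrent
    // weights.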
    int row_sums_dims[2] = {2, num_units};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
      row_sums_size->data[0] = row_sums_dims[0];
      row_sums_size->data[1] = row_sums_dims[1];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }
  return kTfLiteOk;
}

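// Evaluates the RNN with float weights by stepping RnnBatchStep over the
// unrolled sequence.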
TfLiteStatus EvalFloat(const TfLiteTensor* input,
                       const TfLiteTensor* input_weights,
                       const TfLiteTensor* recurrent_weights,
                       const TfLiteTensor* bias,
                       const TfLiteSequenceRNNParams* params,
                       TfLiteTensor* hidden_state, TfLiteTensor* output) {
  // Initialize the bias pointer.
  const float* bias_ptr = GetTensorData<float>(bias);

  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[2];

  // Initialize input_weights and recurrent_weights.
  const float* input_weights_ptr = GetTensorData<float>(input_weights);
  const float* recurrent_weights_ptr = GetTensorData<float>(recurrent_weights);

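  // Time-major input processes all batches of one time step in a single call;
  // batch-major input steps through each batch's sequence with batch_size 1.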
  if (time_major) {
    // Initialize the pointer to hidden state.
    float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
    // Unroll the sequence and use batch operations for efficiency.
    for (int s = 0; s < max_time; s++) {
      // Initialize the pointer to input and output.
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      float* output_ptr_batch =
          GetTensorData<float>(output) + s * num_units * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr,
          input_size, num_units, batch_size, num_units, params->activation,
          hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    // For each batch
    for (int b = 0; b < batch_size; b++) {
      // Initialize the pointer to hidden state.
      float* hidden_state_ptr_batch =
          GetTensorData<float>(hidden_state) + b * num_units;
      for (int s = 0; s < max_time; s++) {
        // Initialize the pointer to input and output.
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        float* output_ptr_batch = GetTensorData<float>(output) +
                                  b * num_units * max_time + s * num_units;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr,
            input_size, num_units, /*batch_size=*/1, num_units,
            params->activation, hidden_state_ptr_batch, output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

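// Evaluates the RNN with 8-bit quantized weights and float activations,
// quantizing the input and hidden state into the scratch tensors on the fly.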
TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* input_weights,
    const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
    const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
    TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
    TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points,
    TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
    bool* compute_row_sums) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[2];

  // Initialize the bias pointer.
  const float* bias_ptr = GetTensorData<float>(bias);

  // Initialize input_weights, recurrent_weights, and temporary storage for
  // quantized values.
  const int8_t* input_weights_ptr = GetTensorData<int8_t>(input_weights);
  const int8_t* recurrent_weights_ptr =
      GetTensorData<int8_t>(recurrent_weights);
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_scratch);
  int8_t* quantized_hidden_state_ptr =
      GetTensorData<int8_t>(hidden_state_scratch);

  // Get the scale of the quantized weights.
  float input_weights_scale = input_weights->params.scale;
  float recurrent_weights_scale = recurrent_weights->params.scale;
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
  int32_t* zero_points_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;

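  // Zero points and cached weight row sums are only needed when the inputs
  // are asymmetrically quantized.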
  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }

  if (time_major) {
    // Initialize the pointer to hidden state.
    float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
    // Unroll the sequence and use batch operations for efficiency.
    for (int s = 0; s < max_time; s++) {
      // Initialize the pointer to input and output.
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      float* output_ptr_batch =
          GetTensorData<float>(output) + s * num_units * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, input_weights_ptr, input_weights_scale,
          recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
          num_units, batch_size, num_units, params->activation,
          quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
          hidden_state_ptr_batch, output_ptr_batch,
          params->asymmetric_quantize_inputs, zero_points_ptr,
          accum_scratch_ptr, row_sums_ptr, compute_row_sums);
    }
  } else {
    // For each batch
    for (int b = 0; b < batch_size; b++) {
      // Initialize the pointer to hidden state.
      float* hidden_state_ptr_batch =
          GetTensorData<float>(hidden_state) + b * num_units;
      for (int s = 0; s < max_time; s++) {
        // Initialize the pointer to input and output.
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        float* output_ptr_batch = GetTensorData<float>(output) +
                                  b * num_units * max_time + s * num_units;
        kernel_utils::RnnBatchStep(
            input_ptr_batch, input_weights_ptr, input_weights_scale,
            recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
            input_size, num_units, /*batch_size=*/1, num_units,
            params->activation, quantized_input_ptr, quantized_hidden_state_ptr,
            scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, row_sums_ptr, compute_row_sums);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  // The hidden_state is a variable input tensor that can be modified.
  TfLiteTensor* hidden_state =
      GetVariableInput(context, node, kHiddenStateTensor);
  TF_LITE_ENSURE(context, hidden_state != nullptr);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, input_weights, recurrent_weights, bias, params,
                       hidden_state, output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      // TODO(mirkov): implement eval with quantized inputs as well.
      auto* op_data = reinterpret_cast<OpData*>(node->user_data);
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 0, &input_quantized));
      TfLiteTensor* hidden_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 2, &scaling_factors));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 3, &accum_scratch));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 4, &zero_points));
      TfLiteTensor* row_sums;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
      return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
                        input_quantized, hidden_state_quantized,
                        scaling_factors, hidden_state, output, zero_points,
                        accum_scratch, row_sums, &op_data->compute_row_sums);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(input_weights->type));
      return kTfLiteError;
  }
}

}  // namespace unidirectional_sequence_rnn

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      unidirectional_sequence_rnn::Init, unidirectional_sequence_rnn::Free,
      unidirectional_sequence_rnn::Prepare, unidirectional_sequence_rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite