/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace rnn {

namespace {

struct OpData {
  // Index of the first of the six scratch tensors requested in Init(); the
  // hybrid path addresses its temporaries relative to this index.
  int scratch_tensor_index;
  // Reset to true in Prepare() so the hybrid path recomputes the cached row
  // sums of the quantized weights.
  bool compute_row_sums = false;
};

}  // namespace

constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kHiddenStateTensor = 4;

// Output tensor.
constexpr int kOutputTensor = 0;

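// Init() reserves six scratch tensors for the hybrid (quantized-weight) path:
// quantized input, quantized hidden state, scaling factors, accumulator
// scratch, zero points, and cached weight row sums; see Prepare() below.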
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  const TfLiteTensor* hidden_state;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));

  // Check that the tensor shapes are consistent with each other and with the
  // input configuration.
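  // Expected shapes:
  //   input:             [batch_size, input_size]
  //   input_weights:     [num_units, input_size]
  //   recurrent_weights: [num_units, num_units]
  //   bias:              [num_units]
  //   hidden_state:      [batch_size, num_units]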
  const int batch_size = input->dims->data[0];
  const int num_units = input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[1],
                    input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input_weights->dims->data[0], bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[0],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[1],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, input_weights->type,
                          recurrent_weights->type);
  TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
  output_size_array->data[0] = batch_size;
  output_size_array->data[1] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

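  // "Hybrid" means float activations combined with quantized (int8/uint8)
  // weights; in that case the temporaries below are required.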
  const bool is_hybrid = IsHybridOp(input, input_weights);

  // Allocate temporary tensors for the hybrid path: quantized copies of the
  // input and hidden state, plus scratch for scaling factors, accumulators,
  // zero points and weight row sums.
  if (is_hybrid) {
    auto* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    node->temporaries = TfLiteIntArrayCreate(6);
    node->temporaries->data[0] = op_data->scratch_tensor_index;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    input_quantized->type = input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
    TfLiteTensor* hidden_state_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &hidden_state_quantized));
    hidden_state_quantized->type = input_weights->type;
    hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
                             hidden_state->dims)) {
      TfLiteIntArray* hidden_state_quantized_size =
          TfLiteIntArrayCopy(hidden_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, hidden_state_quantized,
                                              hidden_state_quantized_size));
    }
    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {num_units, batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
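    // The row sums depend only on the constant weights, so they are kept in a
    // persistent arena buffer and recomputed only when compute_row_sums is
    // set.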
    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->name = "Rnn_row_sums";
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[2] = {2, num_units};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
      row_sums_size->data[0] = row_sums_dims[0];
      row_sums_size->data[1] = row_sums_dims[1];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }
  return kTfLiteOk;
}

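// Float path: one fully-connected RNN step over the whole batch (see
// kernel_utils::RnnBatchStep), conceptually
//   hidden_state = activation(input * input_weights^T +
//                             hidden_state * recurrent_weights^T + bias),
// with the new hidden state also written to the output tensor.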
TfLiteStatus EvalFloat(const TfLiteTensor* input,
                       const TfLiteTensor* input_weights,
                       const TfLiteTensor* recurrent_weights,
                       const TfLiteTensor* bias, const TfLiteRNNParams* params,
                       TfLiteTensor* hidden_state, TfLiteTensor* output) {
  const int batch_size = input->dims->data[0];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[1];
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];

  // Initialize the pointer to hidden state.
  float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
  // Initialize the pointer to input and output.
  const float* input_ptr_batch = GetTensorData<float>(input);
  float* output_ptr_batch = GetTensorData<float>(output);
  // Initialize input_weights, recurrent_weights and bias.
  const float* input_weights_ptr = GetTensorData<float>(input_weights);
  const float* recurrent_weights_ptr = GetTensorData<float>(recurrent_weights);
  const float* bias_ptr = GetTensorData<float>(bias);

  kernel_utils::RnnBatchStep(
      input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr,
      input_size, num_units, batch_size, output_batch_leading_dim,
      params->activation, hidden_state_ptr_batch, output_ptr_batch);
  return kTfLiteOk;
}

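// Hybrid path: the float input and hidden state are quantized per batch,
// multiplied against the int8 weights, and the results are rescaled back to
// float using the recorded scaling factors (plus zero points and weight row
// sums when asymmetric input quantization is enabled).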
TfLiteStatus EvalHybrid(const TfLiteTensor* input,
                        const TfLiteTensor* input_weights,
                        const TfLiteTensor* recurrent_weights,
                        const TfLiteTensor* bias, const TfLiteRNNParams* params,
                        TfLiteTensor* input_scratch,
                        TfLiteTensor* hidden_state_scratch,
                        TfLiteTensor* scaling_factors,
                        TfLiteTensor* hidden_state, TfLiteTensor* output,
                        TfLiteTensor* zero_points, TfLiteTensor* accum_scratch,
                        TfLiteTensor* row_sums, bool* compute_row_sums) {
  const int batch_size = input->dims->data[0];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[1];
  const int output_batch_leading_dim =
      output->dims->data[output->dims->size - 1];

  // Initialize the pointer to hidden state.
  float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
  // Initialize the pointer to input and output.
  const float* input_ptr_batch = GetTensorData<float>(input);
  float* output_ptr_batch = GetTensorData<float>(output);
  // Initialize input_weights, recurrent_weights and bias.
  const int8_t* input_weights_ptr = GetTensorData<int8_t>(input_weights);
  const int8_t* recurrent_weights_ptr =
      GetTensorData<int8_t>(recurrent_weights);
  const float* bias_ptr = GetTensorData<float>(bias);
  // Get the scale of the quantized weights.
  float input_weights_scale = input_weights->params.scale;
  float recurrent_weights_scale = recurrent_weights->params.scale;
  // Initialize temporary storage for quantized values.
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_scratch);
  int8_t* quantized_hidden_state_ptr =
      GetTensorData<int8_t>(hidden_state_scratch);
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
  int32_t* zero_points_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
  kernel_utils::RnnBatchStep(
      input_ptr_batch, input_weights_ptr, input_weights_scale,
      recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
      num_units, batch_size, output_batch_leading_dim, params->activation,
      quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
      hidden_state_ptr_batch, output_ptr_batch,
      params->asymmetric_quantize_inputs, zero_points_ptr, accum_scratch_ptr,
      row_sums_ptr, compute_row_sums);
  return kTfLiteOk;
}

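// Dispatches to the float or hybrid kernel based on the weight type.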
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data);
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  TfLiteTensor* hidden_state =
      GetVariableInput(context, node, kHiddenStateTensor);
  TF_LITE_ENSURE(context, hidden_state != nullptr);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // We already checked that weight types are consistent, so branch on one.
  switch (input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, input_weights, recurrent_weights, bias, params,
                       hidden_state, output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      // TODO(mirkov): implement eval with quantized inputs as well.
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 0, &input_quantized));
      TfLiteTensor* hidden_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 2, &scaling_factors));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 3, &accum_scratch));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 4, &zero_points));
      TfLiteTensor* row_sums;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
      return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
                        input_quantized, hidden_state_quantized,
                        scaling_factors, hidden_state, output, zero_points,
                        accum_scratch, row_sums, &op_data->compute_row_sums);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(input_weights->type));
      return kTfLiteError;
  }
}

}  // namespace rnn

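// Registration hook used by the builtin op resolver: Init/Free/Prepare/Eval
// above implement the basic (non-sequence) RNN op.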
TfLiteRegistration* Register_RNN() {
  static TfLiteRegistration r = {rnn::Init, rnn::Free, rnn::Prepare, rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite