1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <cstddef> |
16 | #include <cstdint> |
17 | |
18 | #include "tensorflow/lite/c/builtin_op_data.h" |
19 | #include "tensorflow/lite/c/common.h" |
20 | #include "tensorflow/lite/kernels/internal/kernel_utils.h" |
21 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
22 | #include "tensorflow/lite/kernels/kernel_util.h" |
23 | |
24 | namespace tflite { |
25 | namespace ops { |
26 | namespace builtin { |
27 | namespace rnn { |
28 | |
29 | namespace { |
30 | |
31 | struct OpData { |
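  // Index of the first scratch tensor registered in Init; the hybrid path
  // uses six consecutive slots starting at this index.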
32 | int scratch_tensor_index; |
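  // Set in Prepare for hybrid ops so the weight row sums are (re)computed on
  // the next invocation; the kernel clears it once they are cached.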
33 | bool compute_row_sums = false; |
34 | }; |
35 | |
36 | } // namespace |
37 | |
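// Input tensors.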
38 | constexpr int kInputTensor = 0; |
39 | constexpr int kWeightsTensor = 1; |
40 | constexpr int kRecurrentWeightsTensor = 2; |
41 | constexpr int kBiasTensor = 3; |
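// The hidden state is a stateful input: it is read and then updated on every
// invocation.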
42 | constexpr int kHiddenStateTensor = 4; |
43 | |
44 | // Output tensor. |
45 | constexpr int kOutputTensor = 0; |
46 | |
47 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
48 | auto* op_data = new OpData(); |
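  // Reserve six scratch slots for the hybrid path: quantized input, quantized
  // hidden state, scaling factors, accumulation scratch, zero points, and row
  // sums (see Prepare).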
49 | context->AddTensors(context, /*tensors_to_add=*/6, |
50 | &op_data->scratch_tensor_index); |
51 | return op_data; |
52 | } |
53 | |
54 | void Free(TfLiteContext* context, void* buffer) { |
55 | delete reinterpret_cast<OpData*>(buffer); |
56 | } |
57 | |
58 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
59 | // Check we have all the inputs and outputs we need. |
60 | TF_LITE_ENSURE_EQ(context, node->inputs->size, 5); |
61 | TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); |
62 | |
63 | const TfLiteTensor* input; |
64 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
65 | const TfLiteTensor* input_weights; |
66 | TF_LITE_ENSURE_OK( |
67 | context, GetInputSafe(context, node, kWeightsTensor, &input_weights)); |
68 | const TfLiteTensor* recurrent_weights; |
69 | TF_LITE_ENSURE_OK( |
70 | context, |
71 | GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights)); |
72 | const TfLiteTensor* bias; |
73 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias)); |
74 | const TfLiteTensor* hidden_state; |
75 | TF_LITE_ENSURE_OK( |
76 | context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state)); |
77 | |
  // Check that the tensor shapes and types are consistent with each other and
  // with the input configuration.
80 | const int batch_size = input->dims->data[0]; |
81 | const int num_units = input_weights->dims->data[0]; |
82 | TF_LITE_ENSURE_EQ(context, input->dims->data[1], |
83 | input_weights->dims->data[1]); |
84 | TF_LITE_ENSURE_EQ(context, input_weights->dims->data[0], bias->dims->data[0]); |
85 | TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[0], |
86 | bias->dims->data[0]); |
87 | TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[1], |
88 | bias->dims->data[0]); |
89 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32); |
90 | TF_LITE_ENSURE_TYPES_EQ(context, input_weights->type, |
91 | recurrent_weights->type); |
92 | TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2); |
93 | TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size); |
94 | TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units); |
95 | |
96 | TfLiteTensor* output; |
97 | TF_LITE_ENSURE_OK(context, |
98 | GetOutputSafe(context, node, kOutputTensor, &output)); |
99 | |
100 | // Resize output. |
101 | TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); |
102 | output_size_array->data[0] = batch_size; |
103 | output_size_array->data[1] = num_units; |
104 | TF_LITE_ENSURE_OK(context, |
105 | context->ResizeTensor(context, output, output_size_array)); |
106 | |
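  // A hybrid op combines float inputs with 8-bit quantized weights.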
107 | const bool is_hybrid = IsHybridOp(input, input_weights); |
108 | |
  // Allocate temporary tensors for the hybrid path: quantized copies of the
  // input and hidden state, plus scaling factors, accumulation scratch, zero
  // points, and row sums.
111 | if (is_hybrid) { |
112 | auto* op_data = reinterpret_cast<OpData*>(node->user_data); |
113 | op_data->compute_row_sums = true; |
114 | TfLiteIntArrayFree(node->temporaries); |
115 | node->temporaries = TfLiteIntArrayCreate(6); |
116 | node->temporaries->data[0] = op_data->scratch_tensor_index; |
117 | TfLiteTensor* input_quantized; |
118 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0, |
119 | &input_quantized)); |
120 | input_quantized->type = input_weights->type; |
121 | input_quantized->allocation_type = kTfLiteArenaRw; |
122 | if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { |
123 | TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); |
124 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, |
125 | input_quantized_size)); |
126 | } |
127 | node->temporaries->data[1] = op_data->scratch_tensor_index + 1; |
128 | TfLiteTensor* hidden_state_quantized; |
129 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1, |
130 | &hidden_state_quantized)); |
131 | hidden_state_quantized->type = input_weights->type; |
132 | hidden_state_quantized->allocation_type = kTfLiteArenaRw; |
133 | if (!TfLiteIntArrayEqual(hidden_state_quantized->dims, |
134 | hidden_state->dims)) { |
135 | TfLiteIntArray* hidden_state_quantized_size = |
136 | TfLiteIntArrayCopy(hidden_state->dims); |
137 | TF_LITE_ENSURE_OK(context, |
138 | context->ResizeTensor(context, hidden_state_quantized, |
139 | hidden_state_quantized_size)); |
140 | } |
141 | node->temporaries->data[2] = op_data->scratch_tensor_index + 2; |
142 | TfLiteTensor* scaling_factors; |
143 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2, |
144 | &scaling_factors)); |
145 | scaling_factors->type = kTfLiteFloat32; |
146 | scaling_factors->allocation_type = kTfLiteArenaRw; |
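    // One scaling factor per batch row, produced when the float input is
    // quantized.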
147 | int scaling_dims[1] = {batch_size}; |
148 | if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { |
149 | TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); |
150 | scaling_factors_size->data[0] = batch_size; |
151 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, |
152 | scaling_factors_size)); |
153 | } |
154 | node->temporaries->data[3] = op_data->scratch_tensor_index + 3; |
155 | TfLiteTensor* accum_scratch; |
156 | TF_LITE_ENSURE_OK( |
157 | context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch)); |
158 | accum_scratch->type = kTfLiteInt32; |
159 | accum_scratch->allocation_type = kTfLiteArenaRw; |
160 | int accum_scratch_dims[2] = {num_units, batch_size}; |
161 | if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, |
162 | accum_scratch_dims)) { |
163 | TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2); |
164 | accum_scratch_size->data[0] = accum_scratch_dims[0]; |
165 | accum_scratch_size->data[1] = accum_scratch_dims[1]; |
166 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch, |
167 | accum_scratch_size)); |
168 | } |
169 | node->temporaries->data[4] = op_data->scratch_tensor_index + 4; |
170 | TfLiteTensor* zero_points; |
171 | TF_LITE_ENSURE_OK( |
172 | context, GetTemporarySafe(context, node, /*index=*/4, &zero_points)); |
173 | zero_points->type = kTfLiteInt32; |
174 | zero_points->allocation_type = kTfLiteArenaRw; |
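    // Per-batch input zero points, used only with asymmetric quantization.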
175 | int zero_points_dims[1] = {batch_size}; |
176 | if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) { |
177 | TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1); |
178 | zero_points_size->data[0] = batch_size; |
179 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points, |
180 | zero_points_size)); |
181 | } |
182 | node->temporaries->data[5] = op_data->scratch_tensor_index + 5; |
183 | TfLiteTensor* row_sums; |
184 | TF_LITE_ENSURE_OK(context, |
185 | GetTemporarySafe(context, node, /*index=*/5, &row_sums)); |
186 | row_sums->type = kTfLiteInt32; |
    row_sums->name = "Rnn_row_sums";
188 | row_sums->allocation_type = kTfLiteArenaRwPersistent; |
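    // Two rows of cached sums: one for the input weights and one for the
    // recurrent weights.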
189 | int row_sums_dims[2] = {2, num_units}; |
190 | if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { |
191 | TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); |
192 | row_sums_size->data[0] = row_sums_dims[0]; |
193 | row_sums_size->data[1] = row_sums_dims[1]; |
194 | TF_LITE_ENSURE_OK( |
195 | context, context->ResizeTensor(context, row_sums, row_sums_size)); |
196 | } |
197 | } |
198 | return kTfLiteOk; |
199 | } |
200 | |
201 | TfLiteStatus EvalFloat(const TfLiteTensor* input, |
202 | const TfLiteTensor* input_weights, |
203 | const TfLiteTensor* recurrent_weights, |
204 | const TfLiteTensor* bias, const TfLiteRNNParams* params, |
205 | TfLiteTensor* hidden_state, TfLiteTensor* output) { |
206 | const int batch_size = input->dims->data[0]; |
207 | const int num_units = input_weights->dims->data[0]; |
208 | const int input_size = input->dims->data[1]; |
209 | const int output_batch_leading_dim = |
210 | output->dims->data[output->dims->size - 1]; |
211 | |
212 | // Initialize the pointer to hidden state. |
213 | float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state); |
  // Initialize the pointers to the input and output.
215 | const float* input_ptr_batch = GetTensorData<float>(input); |
216 | float* output_ptr_batch = GetTensorData<float>(output); |
  // Initialize the pointers to input_weights, recurrent_weights, and bias.
218 | const float* input_weights_ptr = GetTensorData<float>(input_weights); |
219 | const float* recurrent_weights_ptr = GetTensorData<float>(recurrent_weights); |
220 | const float* bias_ptr = GetTensorData<float>(bias); |
221 | |
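  // One step of the vanilla RNN recurrence, updating the hidden state:
  //   hidden_state = output = activation(W * input + U * hidden_state + bias)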
222 | kernel_utils::RnnBatchStep( |
223 | input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr, |
224 | input_size, num_units, batch_size, output_batch_leading_dim, |
225 | params->activation, hidden_state_ptr_batch, output_ptr_batch); |
226 | return kTfLiteOk; |
227 | } |
228 | |
229 | TfLiteStatus EvalHybrid(const TfLiteTensor* input, |
230 | const TfLiteTensor* input_weights, |
231 | const TfLiteTensor* recurrent_weights, |
232 | const TfLiteTensor* bias, const TfLiteRNNParams* params, |
233 | TfLiteTensor* input_scratch, |
234 | TfLiteTensor* hidden_state_scratch, |
235 | TfLiteTensor* scaling_factors, |
236 | TfLiteTensor* hidden_state, TfLiteTensor* output, |
237 | TfLiteTensor* zero_points, TfLiteTensor* accum_scratch, |
238 | TfLiteTensor* row_sums, bool* compute_row_sums) { |
239 | const int batch_size = input->dims->data[0]; |
240 | const int num_units = input_weights->dims->data[0]; |
241 | const int input_size = input->dims->data[1]; |
242 | const int output_batch_leading_dim = |
243 | output->dims->data[output->dims->size - 1]; |
244 | |
245 | // Initialize the pointer to hidden state. |
246 | float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state); |
  // Initialize the pointers to the input and output.
248 | const float* input_ptr_batch = GetTensorData<float>(input); |
249 | float* output_ptr_batch = GetTensorData<float>(output); |
  // Initialize the pointers to input_weights, recurrent_weights, and bias.
251 | const int8_t* input_weights_ptr = GetTensorData<int8_t>(input_weights); |
252 | const int8_t* recurrent_weights_ptr = |
253 | GetTensorData<int8_t>(recurrent_weights); |
254 | const float* bias_ptr = GetTensorData<float>(bias); |
255 | // Get the scale of the quantized weights. |
256 | float input_weights_scale = input_weights->params.scale; |
257 | float recurrent_weights_scale = recurrent_weights->params.scale; |
258 | // Initialize temporary storage for quantized values. |
259 | int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_scratch); |
260 | int8_t* quantized_hidden_state_ptr = |
261 | GetTensorData<int8_t>(hidden_state_scratch); |
262 | float* scaling_factors_ptr = GetTensorData<float>(scaling_factors); |
263 | int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch); |
264 | int32_t* zero_points_ptr = nullptr; |
265 | int32_t* row_sums_ptr = nullptr; |
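  // With asymmetric quantization, zero points and weight row sums are needed
  // to correct the integer matmul for the nonzero input offset.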
266 | if (params->asymmetric_quantize_inputs) { |
267 | zero_points_ptr = GetTensorData<int32_t>(zero_points); |
268 | row_sums_ptr = GetTensorData<int32_t>(row_sums); |
269 | } |
270 | kernel_utils::RnnBatchStep( |
271 | input_ptr_batch, input_weights_ptr, input_weights_scale, |
272 | recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, |
273 | num_units, batch_size, output_batch_leading_dim, params->activation, |
274 | quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr, |
275 | hidden_state_ptr_batch, output_ptr_batch, |
276 | params->asymmetric_quantize_inputs, zero_points_ptr, accum_scratch_ptr, |
277 | row_sums_ptr, compute_row_sums); |
278 | return kTfLiteOk; |
279 | } |
280 | |
281 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
282 | auto* params = reinterpret_cast<TfLiteRNNParams*>(node->builtin_data); |
283 | auto* op_data = reinterpret_cast<OpData*>(node->user_data); |
284 | const TfLiteTensor* input; |
285 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
286 | const TfLiteTensor* input_weights; |
287 | TF_LITE_ENSURE_OK( |
288 | context, GetInputSafe(context, node, kWeightsTensor, &input_weights)); |
289 | const TfLiteTensor* recurrent_weights; |
290 | TF_LITE_ENSURE_OK( |
291 | context, |
292 | GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights)); |
293 | const TfLiteTensor* bias; |
294 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias)); |
295 | TfLiteTensor* hidden_state = |
296 | GetVariableInput(context, node, kHiddenStateTensor); |
297 | TF_LITE_ENSURE(context, hidden_state != nullptr); |
298 | TfLiteTensor* output; |
299 | TF_LITE_ENSURE_OK(context, |
300 | GetOutputSafe(context, node, kOutputTensor, &output)); |
301 | |
302 | // We already checked that weight types are consistent, so branch on one. |
303 | switch (input_weights->type) { |
304 | case kTfLiteFloat32: |
305 | return EvalFloat(input, input_weights, recurrent_weights, bias, params, |
306 | hidden_state, output); |
307 | case kTfLiteUInt8: |
308 | case kTfLiteInt8: { |
309 | // TODO(mirkov): implement eval with quantized inputs as well. |
310 | TfLiteTensor* input_quantized; |
311 | TF_LITE_ENSURE_OK(context, |
312 | GetTemporarySafe(context, node, 0, &input_quantized)); |
313 | TfLiteTensor* hidden_state_quantized; |
314 | TF_LITE_ENSURE_OK( |
315 | context, GetTemporarySafe(context, node, 1, &hidden_state_quantized)); |
316 | TfLiteTensor* scaling_factors; |
317 | TF_LITE_ENSURE_OK(context, |
318 | GetTemporarySafe(context, node, 2, &scaling_factors)); |
319 | TfLiteTensor* accum_scratch; |
320 | TF_LITE_ENSURE_OK(context, |
321 | GetTemporarySafe(context, node, 3, &accum_scratch)); |
322 | TfLiteTensor* zero_points; |
323 | TF_LITE_ENSURE_OK(context, |
324 | GetTemporarySafe(context, node, 4, &zero_points)); |
325 | TfLiteTensor* row_sums; |
326 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums)); |
327 | return EvalHybrid(input, input_weights, recurrent_weights, bias, params, |
328 | input_quantized, hidden_state_quantized, |
329 | scaling_factors, hidden_state, output, zero_points, |
330 | accum_scratch, row_sums, &op_data->compute_row_sums); |
331 | } |
332 | default: |
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
334 | TfLiteTypeGetName(input_weights->type)); |
335 | return kTfLiteError; |
336 | } |
337 | } |
338 | |
339 | } // namespace rnn |
340 | |
341 | TfLiteRegistration* Register_RNN() { |
342 | static TfLiteRegistration r = {rnn::Init, rnn::Free, rnn::Prepare, rnn::Eval}; |
343 | return &r; |
344 | } |
345 | |
346 | } // namespace builtin |
347 | } // namespace ops |
348 | } // namespace tflite |
349 | |