/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace unidirectional_sequence_rnn {

namespace {
struct OpData {
  int scratch_tensor_index;
  bool compute_row_sums = false;
};

}  // namespace

// Input tensors.
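// Expected shapes (checked in Prepare):
//   input:             [batch, time, input_size], or [time, batch, input_size]
//                      when params->time_major is set
//   weights:           [num_units, input_size]
//   recurrent weights: [num_units, num_units]
//   bias:              [num_units]
//   hidden state:      [batch, num_units], a variable tensor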
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
constexpr int kHiddenStateTensor = 4;

// Output tensor.
constexpr int kOutputTensor = 0;

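// Reserves six scratch tensor slots; Prepare() only wires them up as node
// temporaries when the op runs in hybrid (quantized-weight) mode.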
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  const TfLiteTensor* hidden_state;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kHiddenStateTensor, &hidden_state));

  // Check that the tensor parameters are internally consistent and match the
  // input configuration.
  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int num_units = input_weights->dims->data[0];
  TF_LITE_ENSURE_EQ(context, input->dims->data[2],
                    input_weights->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, input_weights->dims->data[0], bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[0],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_EQ(context, recurrent_weights->dims->data[1],
                    bias->dims->data[0]);
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, input_weights->type,
                          recurrent_weights->type);
  TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
  TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
  output_size_array->data[0] = (time_major) ? max_time : batch_size;
  output_size_array->data[1] = (time_major) ? batch_size : max_time;
  output_size_array->data[2] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

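  // Hybrid mode: float activations combined with uint8/int8 quantized weights.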
  const bool is_hybrid = IsHybridOp(input, input_weights);

  // Allocate temporary tensors to store quantized values of input and
  // hidden_state tensors.
  if (is_hybrid) {
    auto* op_data = reinterpret_cast<OpData*>(node->user_data);
    op_data->compute_row_sums = true;
    TfLiteIntArrayFree(node->temporaries);
    node->temporaries = TfLiteIntArrayCreate(6);
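    // Temporaries: 0 = quantized input, 1 = quantized hidden state,
    // 2 = scaling factors, 3 = accumulator scratch, 4 = zero points,
    // 5 = weight row sums.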
    node->temporaries->data[0] = op_data->scratch_tensor_index;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    input_quantized->type = input_weights->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }
    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
    TfLiteTensor* hidden_state_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &hidden_state_quantized));
    hidden_state_quantized->type = input_weights->type;
    hidden_state_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(hidden_state_quantized->dims,
                             hidden_state->dims)) {
      TfLiteIntArray* hidden_state_quantized_size =
          TfLiteIntArrayCopy(hidden_state->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, hidden_state_quantized,
                                              hidden_state_quantized_size));
    }
    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {num_units, batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_scratch_size = TfLiteIntArrayCreate(2);
      accum_scratch_size->data[0] = accum_scratch_dims[0];
      accum_scratch_size->data[1] = accum_scratch_dims[1];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, accum_scratch,
                                                       accum_scratch_size));
    }
    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
    zero_points->type = kTfLiteInt32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }
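    // One row-sum vector per weight matrix (input and recurrent); the sums are
    // used to correct for asymmetrically quantized inputs and are cached
    // across invocations, hence kTfLiteArenaRwPersistent.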
    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[2] = {2, num_units};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2);
      row_sums_size->data[0] = row_sums_dims[0];
      row_sums_size->data[1] = row_sums_dims[1];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }
  return kTfLiteOk;
}

TfLiteStatus EvalFloat(const TfLiteTensor* input,
                       const TfLiteTensor* input_weights,
                       const TfLiteTensor* recurrent_weights,
                       const TfLiteTensor* bias,
                       const TfLiteSequenceRNNParams* params,
                       TfLiteTensor* hidden_state, TfLiteTensor* output) {
  // Initialize the bias pointer.
  const float* bias_ptr = GetTensorData<float>(bias);

  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[2];

  // Initialize pointers to the input and recurrent weights.
  const float* input_weights_ptr = GetTensorData<float>(input_weights);
  const float* recurrent_weights_ptr = GetTensorData<float>(recurrent_weights);

  if (time_major) {
    // Initialize the pointer to hidden state.
    float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
    // Unroll the sequence and use batch operations for efficiency.
    for (int s = 0; s < max_time; s++) {
      // Initialize the pointer to input and output.
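      // Time-major data is laid out as [time, batch, feature], so step s
      // starts s * batch_size * (features per step) elements into the buffer.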
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      float* output_ptr_batch =
          GetTensorData<float>(output) + s * num_units * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr,
          input_size, num_units, batch_size, num_units, params->activation,
          hidden_state_ptr_batch, output_ptr_batch);
    }
  } else {
    // For each batch
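    // Batch-major data keeps each sequence contiguous, so step each sequence
    // one at a time (batch size 1) against its own hidden-state row.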
    for (int b = 0; b < batch_size; b++) {
      // Initialize the pointer to hidden state.
      float* hidden_state_ptr_batch =
          GetTensorData<float>(hidden_state) + b * num_units;
      for (int s = 0; s < max_time; s++) {
        // Initialize the pointer to input and output.
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        float* output_ptr_batch = GetTensorData<float>(output) +
                                  b * num_units * max_time + s * num_units;

        kernel_utils::RnnBatchStep(
            input_ptr_batch, input_weights_ptr, recurrent_weights_ptr,
            bias_ptr, input_size, num_units, /*batch_size=*/1, num_units,
            params->activation, hidden_state_ptr_batch, output_ptr_batch);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus EvalHybrid(
    const TfLiteTensor* input, const TfLiteTensor* input_weights,
    const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias,
    const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch,
    TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors,
    TfLiteTensor* hidden_state, TfLiteTensor* output, TfLiteTensor* zero_points,
    TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
    bool* compute_row_sums) {
  const bool time_major = params->time_major;
  const int batch_size =
      (time_major) ? input->dims->data[1] : input->dims->data[0];
  const int max_time =
      (time_major) ? input->dims->data[0] : input->dims->data[1];
  const int num_units = input_weights->dims->data[0];
  const int input_size = input->dims->data[2];

  // Initialize the bias pointer.
  const float* bias_ptr = GetTensorData<float>(bias);

  // Initialize pointers to input_weights, recurrent_weights, and the temporary
  // storage for quantized values.
  const int8_t* input_weights_ptr = GetTensorData<int8_t>(input_weights);
  const int8_t* recurrent_weights_ptr =
      GetTensorData<int8_t>(recurrent_weights);
  int8_t* quantized_input_ptr = GetTensorData<int8_t>(input_scratch);
  int8_t* quantized_hidden_state_ptr =
      GetTensorData<int8_t>(hidden_state_scratch);

  // Get the scale of the quantized weights.
  float input_weights_scale = input_weights->params.scale;
  float recurrent_weights_scale = recurrent_weights->params.scale;
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* accum_scratch_ptr = GetTensorData<int32_t>(accum_scratch);
  int32_t* zero_points_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;

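  // Zero points and row sums are only needed when inputs are quantized
  // asymmetrically; the row sums let the kernel fold the input zero point out
  // of the integer accumulation.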
  if (params->asymmetric_quantize_inputs) {
    zero_points_ptr = GetTensorData<int32_t>(zero_points);
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }

  if (time_major) {
    // Initialize the pointer to hidden state.
    float* hidden_state_ptr_batch = GetTensorData<float>(hidden_state);
    // Unroll the sequence and use batch operations for efficiency.
    for (int s = 0; s < max_time; s++) {
      // Initialize the pointer to input and output.
      const float* input_ptr_batch =
          GetTensorData<float>(input) + s * input_size * batch_size;
      float* output_ptr_batch =
          GetTensorData<float>(output) + s * num_units * batch_size;

      kernel_utils::RnnBatchStep(
          input_ptr_batch, input_weights_ptr, input_weights_scale,
          recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size,
          num_units, batch_size, num_units, params->activation,
          quantized_input_ptr, quantized_hidden_state_ptr, scaling_factors_ptr,
          hidden_state_ptr_batch, output_ptr_batch,
          params->asymmetric_quantize_inputs, zero_points_ptr,
          accum_scratch_ptr, row_sums_ptr, compute_row_sums);
    }
  } else {
    // For each batch
    for (int b = 0; b < batch_size; b++) {
      // Initialize the pointer to hidden state.
      float* hidden_state_ptr_batch =
          GetTensorData<float>(hidden_state) + b * num_units;
      for (int s = 0; s < max_time; s++) {
        // Initialize the pointer to input and output.
        const float* input_ptr_batch = GetTensorData<float>(input) +
                                       b * input_size * max_time +
                                       s * input_size;
        float* output_ptr_batch = GetTensorData<float>(output) +
                                  b * num_units * max_time + s * num_units;
        kernel_utils::RnnBatchStep(
            input_ptr_batch, input_weights_ptr, input_weights_scale,
            recurrent_weights_ptr, recurrent_weights_scale, bias_ptr,
            input_size, num_units, /*batch_size=*/1, num_units,
            params->activation, quantized_input_ptr, quantized_hidden_state_ptr,
            scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch,
            params->asymmetric_quantize_inputs, zero_points_ptr,
            accum_scratch_ptr, row_sums_ptr, compute_row_sums);
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* input_weights;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTensor, &input_weights));
  const TfLiteTensor* recurrent_weights;
  TF_LITE_ENSURE_OK(
      context,
      GetInputSafe(context, node, kRecurrentWeightsTensor, &recurrent_weights));
  const TfLiteTensor* bias;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kBiasTensor, &bias));
  // The hidden_state is a variable input tensor that can be modified.
  TfLiteTensor* hidden_state =
      GetVariableInput(context, node, kHiddenStateTensor);
  TF_LITE_ENSURE(context, hidden_state != nullptr);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

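  // Dispatch on the weight type: float weights run the float kernel, while
  // uint8/int8 weights run the hybrid kernel with the scratch tensors set up
  // in Prepare.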
  switch (input_weights->type) {
    case kTfLiteFloat32:
      return EvalFloat(input, input_weights, recurrent_weights, bias, params,
                       hidden_state, output);
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      // TODO(mirkov): implement eval with quantized inputs as well.
      auto* op_data = reinterpret_cast<OpData*>(node->user_data);
      TfLiteTensor* input_quantized;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 0, &input_quantized));
      TfLiteTensor* hidden_state_quantized;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, 1, &hidden_state_quantized));
      TfLiteTensor* scaling_factors;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 2, &scaling_factors));
      TfLiteTensor* accum_scratch;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 3, &accum_scratch));
      TfLiteTensor* zero_points;
      TF_LITE_ENSURE_OK(context,
                        GetTemporarySafe(context, node, 4, &zero_points));
      TfLiteTensor* row_sums;
      TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &row_sums));
      return EvalHybrid(input, input_weights, recurrent_weights, bias, params,
                        input_quantized, hidden_state_quantized,
                        scaling_factors, hidden_state, output, zero_points,
                        accum_scratch, row_sums, &op_data->compute_row_sums);
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(input_weights->type));
      return kTfLiteError;
  }
}

}  // namespace unidirectional_sequence_rnn

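// Makes the op available to an interpreter through an op resolver; a minimal
// sketch of typical wiring (resolver name assumed):
//   resolver.AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
//                       Register_UNIDIRECTIONAL_SEQUENCE_RNN());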
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() {
  static TfLiteRegistration r = {
      unidirectional_sequence_rnn::Init, unidirectional_sequence_rnn::Free,
      unidirectional_sequence_rnn::Prepare, unidirectional_sequence_rnn::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite