1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#include <algorithm>
#include <complex>
#include <vector>

#include "third_party/fft2d/fft2d.h"
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
31
32namespace tflite {
33namespace ops {
34namespace builtin {
35namespace rfft2d {
36
37using std::complex;
38
39constexpr int kInputTensor = 0;
40constexpr int kFftLengthTensor = 1;
41constexpr int kOutputTensor = 0;
42constexpr int kFftIntegerWorkingAreaTensor = 0;
43constexpr int kFftDoubleWorkingAreaTensor = 1;
44constexpr int kTensorNotAllocated = -1;
45
46struct OpData {
47 // IDs are the arbitrary identifiers used by TF Lite to identify and access
48 // memory buffers.
49 int fft_integer_working_area_id = kTensorNotAllocated;
50 int fft_double_working_area_id = kTensorNotAllocated;
51};
52
// Returns true iff `v` is nonzero and has exactly one bit set, i.e. is a
// power of two.
bool IsPowerOfTwo(uint32_t v) { return v != 0 && (v & (v - 1)) == 0; }
54
55static TfLiteStatus InitTemporaryTensors(TfLiteContext* context,
56 TfLiteNode* node) {
57 OpData* data = reinterpret_cast<OpData*>(node->user_data);
58 // The prepare function may be executed multiple times. But temporary tensors
59 // only need to be initiated once.
60 if (data->fft_integer_working_area_id != kTensorNotAllocated &&
61 data->fft_double_working_area_id != kTensorNotAllocated) {
62 return kTfLiteOk;
63 }
64
65 TfLiteIntArrayFree(node->temporaries);
66 // Create two temporary tensors.
67 node->temporaries = TfLiteIntArrayCreate(2);
68 int first_new_index;
69 TF_LITE_ENSURE_STATUS(context->AddTensors(context, 2, &first_new_index));
70 node->temporaries->data[kFftIntegerWorkingAreaTensor] = first_new_index;
71 data->fft_integer_working_area_id = first_new_index;
72 node->temporaries->data[kFftDoubleWorkingAreaTensor] = first_new_index + 1;
73 data->fft_double_working_area_id = first_new_index + 1;
74
75 // Set up FFT integer working area buffer.
76 TfLiteTensor* fft_integer_working_area;
77 TF_LITE_ENSURE_OK(
78 context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
79 &fft_integer_working_area));
80 fft_integer_working_area->type = kTfLiteInt32;
81 // If fft_length is not a constant tensor, fft_integer_working_area will be
82 // set to dynamic later in Prepare.
83 fft_integer_working_area->allocation_type = kTfLiteArenaRw;
84
85 // Set up FFT double working area buffer.
86 TfLiteTensor* fft_double_working_area;
87 TF_LITE_ENSURE_OK(context,
88 GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
89 &fft_double_working_area));
90 // fft_double_working_area is a double tensor. Ideally, double should be
91 // added into tflite data types. However, since fft_double_working_area is a
92 // temporary tensor, and there are no ops having double input/output tensors
93 // in tflite at this point, adding double as a tflite data type may confuse
94 // users that double is supported. As a results, kTfLiteInt64 is used here
95 // for memory allocation. And it will be cast into double in Eval when being
96 // used.
97 fft_double_working_area->type = kTfLiteInt64;
98 // If fft_length is not a constant tensor, fft_double_working_area will be
99 // set to dynamic later in Prepare.
100 fft_double_working_area->allocation_type = kTfLiteArenaRw;
101
102 return kTfLiteOk;
103}
104
105TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context,
106 TfLiteNode* node) {
107 const TfLiteTensor* input;
108 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
109 const int num_dims = NumDimensions(input);
110 TF_LITE_ENSURE(context, num_dims >= 2);
111 const TfLiteTensor* fft_length;
112 TF_LITE_ENSURE_OK(context,
113 GetInputSafe(context, node, kFftLengthTensor, &fft_length));
114 const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
115 // The lib, fft2d, can only handle fft_lengths of power of 2.
116 TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0]));
117 TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[1]));
118
119 int fft_height, fft_width;
120 fft_height = fft_length_data[0];
121 fft_width = fft_length_data[1];
122 int fft_working_length = std::max(fft_height, fft_width / 2);
123 int half_fft_working_length = fft_working_length / 2;
124
125 // Resize output tensor.
126 TfLiteTensor* output;
127 TF_LITE_ENSURE_OK(context,
128 GetOutputSafe(context, node, kOutputTensor, &output));
129 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
130 output_shape->data[num_dims - 2] = fft_length_data[0];
131 output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1;
132 TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape));
133
134 // Resize temporary tensors, fft_integer_working_area.
135 TfLiteTensor* fft_integer_working_area;
136 TF_LITE_ENSURE_OK(
137 context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
138 &fft_integer_working_area));
139 TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1);
140 fft_integer_working_area_shape->data[0] =
141 2 + static_cast<int>(sqrt(fft_working_length));
142 TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_integer_working_area,
143 fft_integer_working_area_shape));
144
145 // Resize temporary tensors, fft_double_working_area.
146 TfLiteTensor* fft_double_working_area;
147 TF_LITE_ENSURE_OK(context,
148 GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
149 &fft_double_working_area));
150 TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1);
151 fft_double_working_area_shape->data[0] =
152 half_fft_working_length + fft_width / 4;
153 TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_double_working_area,
154 fft_double_working_area_shape));
155
156 return kTfLiteOk;
157}
158
159void* Init(TfLiteContext* context, const char* buffer, size_t length) {
160 auto* data = new OpData;
161 return data;
162}
163
164void Free(TfLiteContext* context, void* buffer) {
165 delete reinterpret_cast<OpData*>(buffer);
166}
167
// Validates inputs, sets up the scratch tensors, and either sizes all
// tensors now (constant fft_length) or defers sizing to Eval by marking the
// output and scratch tensors dynamic (non-constant fft_length).
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // Check type and shape of the input tensor: at least 2-D, float32 only.
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TF_LITE_ENSURE(context, NumDimensions(input) >= 2);
  if (input->type != kTfLiteFloat32) {
    TF_LITE_KERNEL_LOG(context,
                       "Type '%s' for input is not supported by rfft2d.",
                       TfLiteTypeGetName(input->type));
    return kTfLiteError;
  }

  // Check type and shape of the fft_length tensor: a 1-D int32 tensor with
  // exactly two elements, interpreted as (fft_height, fft_width).
  const TfLiteTensor* fft_length;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFftLengthTensor, &fft_length));
  const RuntimeShape fft_length_shape = GetTensorShape(fft_length);

  TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1);
  TF_LITE_ENSURE_EQ(context, fft_length_shape.Dims(0), 2);
  if (fft_length->type != kTfLiteInt32) {
    TF_LITE_KERNEL_LOG(context,
                       "Type '%s' for fft_length is not supported by rfft2d.",
                       TfLiteTypeGetName(fft_length->type));
    return kTfLiteError;
  }

  // Setup temporary tensors for fft computation.
  TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node));

  // Set output type.
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  output->type = kTfLiteComplex64;

  // Exit early if fft_length is a non-const tensor. Set output tensor and
  // temporary tensors to dynamic, so that their tensor sizes can be determined
  // in Eval.
  if (!IsConstantTensor(fft_length)) {
    TfLiteTensor* fft_integer_working_area;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
                                  &fft_integer_working_area));
    TfLiteTensor* fft_double_working_area;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
                                  &fft_double_working_area));
    SetTensorToDynamic(fft_integer_working_area);
    SetTensorToDynamic(fft_double_working_area);
    SetTensorToDynamic(output);
    return kTfLiteOk;
  }

  // fft_length is constant: all shapes can be fixed right now.
  TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
  return kTfLiteOk;
}
228
// Reorder the result so that it matches the pattern of tf.signal.rfft2d.
// In tf.signal.fft2d the frequency matrix of a 4x4 input is
// [[F(0, 0), F(0, 1/4), F(0, 2/4)],
//  [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
//  [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
//  [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
// While in rdft2d, the frequency matrix of a 4x4 input is
// [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0],
//  [ F(-1/4, 0), F(-1/4, -1/4), 0],
//  [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
//  [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]]
// Since real fft has the property that
//   Real(u,v) = Real(-u, -v)
//   Img(u,v) = - Img(-u, -v)
// Result of rdft2d can be reordered and match the pattern of tf.signal.rfft2d.
// For example,
//   Real(-3/4, 0) = Real(1/4, 0) = Real(-1/4, 0)
//   Img(-3/4, 0) = Img(1/4, 0) = -Img(-1/4, 0)
//
// `fft_input_output` holds fft_height rows of fft_width + 2 doubles with
// real/imaginary parts interleaved; the reorder happens fully in place,
// using the two spare columns at the end of each row.
void Rfft2dReorder(int fft_height, int fft_width, double** fft_input_output) {
  int fft_height_half;
  ruy::profiler::ScopeLabel label("Rfft2dReorder");
  double real, img;

  fft_height_half = fft_height >> 1;
  // Use 4x4 input as an example, reorder the frequency matrix from
  // [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0],
  //  [ F(-1/4, 0), F(-1/4, -1/4), 0],
  //  [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0],
  //  [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]]
  // to
  // [[F(0, 0), F(0, -1/4), F(0, -2/4)],
  //  [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
  //  [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
  //  [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
  for (int i = fft_height_half + 1; i < fft_height; ++i) {
    // Move the packed column-0 entries of the lower half-rows into the spare
    // last column (for row i and its mirror row fft_height - i), then rebuild
    // column 0 of row i from its conjugate-symmetric partner.
    real = fft_input_output[i][0];
    img = fft_input_output[i][1];
    fft_input_output[i][fft_width] = img;
    fft_input_output[i][fft_width + 1] = real;
    fft_input_output[fft_height - i][fft_width] = img;
    fft_input_output[fft_height - i][fft_width + 1] = -real;
    fft_input_output[i][0] = fft_input_output[fft_height - i][0];
    fft_input_output[i][1] = -fft_input_output[fft_height - i][1];
  }

  // Rows 0 and fft_height/2 each pack two purely-real spectrum values into
  // one complex slot; unpack them into the last column and zero the now
  // spurious imaginary parts.
  double temp = fft_input_output[0][1];
  fft_input_output[0][fft_width + 1] = 0;
  fft_input_output[0][1] = 0;
  fft_input_output[fft_height_half][fft_width] =
      fft_input_output[fft_height_half][1];
  fft_input_output[fft_height_half][fft_width + 1] = 0;
  fft_input_output[fft_height_half][1] = 0;
  fft_input_output[0][fft_width] = temp;

  // Reorder the frequency matrix from
  // [[F(0, 0), F(0, -1/4), F(0, -2/4)],
  //  [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)],
  //  [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)],
  //  [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]]
  // to
  // [[F(0, 0), F(0, 1/4), F(0, 2/4)],
  //  [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)],
  //  [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)],
  //  [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]]
  // Negating every odd column flips the sign of each imaginary part, i.e.
  // conjugates each entry, which flips the frequency signs per the symmetry
  // noted above.
  for (int i = 0; i < fft_height; ++i) {
    for (int j = 1; j < fft_width + 2; j += 2) {
      fft_input_output[i][j] = -fft_input_output[i][j];
    }
  }
}
299
300void Rfft2dImpl(int fft_height, int fft_width, double** fft_input_output,
301 int* fft_integer_working_area_data,
302 double* fft_double_working_area_data) {
303 ruy::profiler::ScopeLabel label("Rfft2dImpl");
304
305 // Working data areas for the FFT routines.
306 double* fft_dynamic_working_area = nullptr;
307 const int kForwardFft = 1;
308 rdft2d(fft_height, fft_width, kForwardFft, fft_input_output,
309 fft_dynamic_working_area, fft_integer_working_area_data,
310 fft_double_working_area_data);
311 Rfft2dReorder(fft_height, fft_width, fft_input_output);
312}
313
// Copies one input slice into the FFT buffer, widening float -> double, and
// zero-fills every element outside the valid input region (both the padded
// columns of each copied row and any rows past the input height). Each of
// the fft_height rows is fft_width + 2 doubles wide; the two extra columns
// are used later by the spectrum reordering step.
void PrepareInputBuffer(const float* input_data, int input_height,
                        int input_width, int fft_height, int fft_width,
                        double** fft_input_output) {
  const int copy_rows = std::min(input_height, fft_height);
  const int copy_cols = std::min(input_width, fft_width);
  const int row_width = fft_width + 2;
  for (int row = 0; row < fft_height; ++row) {
    double* dst = fft_input_output[row];
    if (row < copy_rows) {
      const float* src = input_data + row * input_width;
      std::copy(src, src + copy_cols, dst);
      std::fill(dst + copy_cols, dst + row_width, 0.0);
    } else {
      // Entirely outside the input: zero the whole row.
      std::fill(dst, dst + row_width, 0.0);
    }
  }
}
337
// Packs the interleaved (real, imag) doubles produced by the FFT into the
// complex<float> output, row by row. Only the fft_width / 2 + 1
// non-redundant columns of each row are emitted.
void PrepareOutputBuffer(complex<float>* output_data, int fft_height,
                         int fft_width, double** fft_input_output) {
  const int out_cols = fft_width / 2 + 1;
  complex<float>* out = output_data;
  for (int row = 0; row < fft_height; ++row) {
    const double* src = fft_input_output[row];
    for (int col = 0; col < out_cols; ++col) {
      *out++ = complex<float>(static_cast<float>(src[2 * col]),
                              static_cast<float>(src[2 * col + 1]));
    }
  }
}
348
349TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) {
350 const TfLiteTensor* input;
351 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
352 const float* input_data = GetTensorData<float>(input);
353 const TfLiteTensor* fft_length;
354 TF_LITE_ENSURE_OK(context,
355 GetInputSafe(context, node, kFftLengthTensor, &fft_length));
356 const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
357 TfLiteTensor* output;
358 TF_LITE_ENSURE_OK(context,
359 GetOutputSafe(context, node, kOutputTensor, &output));
360 complex<float>* output_data = GetTensorData<complex<float>>(output);
361
362 int fft_height, fft_width;
363 fft_height = fft_length_data[0];
364 fft_width = fft_length_data[1];
365
366 // FFT is processed for every slice on the inner most 2 dimensions.
367 // Count the number of slices in the input tensor.
368 const RuntimeShape input_shape = GetTensorShape(input);
369 const int input_dims_count = input_shape.DimensionsCount();
370 const auto* input_dims_data = input_shape.DimsData();
371 int num_slices = 1;
372 for (int i = 0; i < input_dims_count - 2; ++i) {
373 num_slices *= input_dims_data[i];
374 }
375
376 int input_height = input_dims_data[input_dims_count - 2];
377 int input_width = input_dims_data[input_dims_count - 1];
378 int input_slice_size = input_height * input_width;
379 int output_slice_size = fft_height * (fft_width / 2 + 1);
380
381 // Create input/output buffer for FFT
382 double** fft_input_output = new double*[fft_height];
383 for (int i = 0; i < fft_height; ++i) {
384 fft_input_output[i] = new double[fft_width + 2];
385 }
386
387 // Get buffer for integer working area.
388 TfLiteTensor* fft_integer_working_area;
389 TF_LITE_ENSURE_OK(
390 context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor,
391 &fft_integer_working_area));
392 int* fft_integer_working_area_data =
393 GetTensorData<int>(fft_integer_working_area);
394
395 // Get buffer for double working area.
396 TfLiteTensor* fft_double_working_area;
397 TF_LITE_ENSURE_OK(context,
398 GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor,
399 &fft_double_working_area));
400 // Get double value out of the memory of fft_double_working_area_data.
401 double* fft_double_working_area_data = reinterpret_cast<double*>(
402 GetTensorData<int64_t>(fft_double_working_area));
403
404 // Process every slice in the input buffer
405 for (int i = 0; i < num_slices; ++i) {
406 PrepareInputBuffer(input_data, input_height, input_width, fft_height,
407 fft_width, fft_input_output);
408 memset(fft_integer_working_area_data, 0, fft_integer_working_area->bytes);
409 memset(fft_double_working_area_data, 0, fft_double_working_area->bytes);
410 Rfft2dImpl(fft_height, fft_width, fft_input_output,
411 fft_integer_working_area_data, fft_double_working_area_data);
412 PrepareOutputBuffer(output_data, fft_height, fft_width, fft_input_output);
413 input_data += input_slice_size;
414 output_data += output_slice_size;
415 }
416
417 // Delete the input buffer
418 for (int i = 0; i < fft_height; ++i) {
419 delete[] fft_input_output[i];
420 }
421 delete[] fft_input_output;
422
423 return kTfLiteOk;
424}
425
// Runs the rfft2d computation. When fft_length is a non-constant tensor the
// output and scratch tensors are (re)sized here; otherwise their shapes were
// fixed in Prepare and are only re-validated against fft_length.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* fft_length;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kFftLengthTensor, &fft_length));
  const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  if (output->type != kTfLiteComplex64) {
    TF_LITE_KERNEL_LOG(context,
                       "Type '%s' for output is not supported by rfft2d.",
                       TfLiteTypeGetName(output->type));
    return kTfLiteError;
  }

  // Resize the output tensor if the fft_length tensor is not constant.
  // Otherwise, check if the output shape is correct.
  if (!IsConstantTensor(fft_length)) {
    TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node));
  } else {
    // Expected output shape: input batch dims + (fft_height, fft_width/2+1).
    int num_dims_output = NumDimensions(output);
    const RuntimeShape output_shape = GetTensorShape(output);
    TF_LITE_ENSURE_EQ(context, num_dims_output, NumDimensions(input));
    TF_LITE_ENSURE(context, num_dims_output >= 2);
    TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 2),
                      fft_length_data[0]);
    TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 1),
                      fft_length_data[1] / 2 + 1);
  }

  return Rfft2dHelper(context, node);
}
461
462} // namespace rfft2d
463
464TfLiteRegistration* Register_RFFT2D() {
465 static TfLiteRegistration r = {rfft2d::Init, rfft2d::Free, rfft2d::Prepare,
466 rfft2d::Eval};
467 return &r;
468}
469
470} // namespace builtin
471} // namespace ops
472} // namespace tflite
473