1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <math.h> |
17 | #include <stddef.h> |
18 | #include <stdint.h> |
19 | #include <string.h> |
20 | |
#include <algorithm>
#include <complex>
#include <vector>
23 | |
24 | #include "third_party/fft2d/fft2d.h" |
25 | #include "ruy/profiler/instrumentation.h" // from @ruy |
26 | #include "tensorflow/lite/c/common.h" |
27 | #include "tensorflow/lite/kernels/internal/tensor.h" |
28 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
29 | #include "tensorflow/lite/kernels/internal/types.h" |
30 | #include "tensorflow/lite/kernels/kernel_util.h" |
31 | |
32 | namespace tflite { |
33 | namespace ops { |
34 | namespace builtin { |
35 | namespace rfft2d { |
36 | |
using std::complex;

// Sentinel stored in OpData until a temporary tensor has been created.
constexpr int kTensorNotAllocated{-1};

// Indices into node->inputs.
constexpr int kInputTensor{0};
constexpr int kFftLengthTensor{1};

// Index into node->outputs.
constexpr int kOutputTensor{0};

// Indices into node->temporaries.
constexpr int kFftIntegerWorkingAreaTensor{0};
constexpr int kFftDoubleWorkingAreaTensor{1};
45 | |
46 | struct OpData { |
47 | // IDs are the arbitrary identifiers used by TF Lite to identify and access |
48 | // memory buffers. |
49 | int fft_integer_working_area_id = kTensorNotAllocated; |
50 | int fft_double_working_area_id = kTensorNotAllocated; |
51 | }; |
52 | |
// Returns true iff `v` has exactly one bit set (1, 2, 4, ..., 2^31).
// Zero is not a power of two.
bool IsPowerOfTwo(uint32_t v) {
  const bool non_zero = v != 0;
  const bool single_bit = (v & (v - 1)) == 0;
  return non_zero && single_bit;
}
54 | |
55 | static TfLiteStatus InitTemporaryTensors(TfLiteContext* context, |
56 | TfLiteNode* node) { |
57 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
58 | // The prepare function may be executed multiple times. But temporary tensors |
59 | // only need to be initiated once. |
60 | if (data->fft_integer_working_area_id != kTensorNotAllocated && |
61 | data->fft_double_working_area_id != kTensorNotAllocated) { |
62 | return kTfLiteOk; |
63 | } |
64 | |
65 | TfLiteIntArrayFree(node->temporaries); |
66 | // Create two temporary tensors. |
67 | node->temporaries = TfLiteIntArrayCreate(2); |
68 | int first_new_index; |
69 | TF_LITE_ENSURE_STATUS(context->AddTensors(context, 2, &first_new_index)); |
70 | node->temporaries->data[kFftIntegerWorkingAreaTensor] = first_new_index; |
71 | data->fft_integer_working_area_id = first_new_index; |
72 | node->temporaries->data[kFftDoubleWorkingAreaTensor] = first_new_index + 1; |
73 | data->fft_double_working_area_id = first_new_index + 1; |
74 | |
75 | // Set up FFT integer working area buffer. |
76 | TfLiteTensor* fft_integer_working_area; |
77 | TF_LITE_ENSURE_OK( |
78 | context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor, |
79 | &fft_integer_working_area)); |
80 | fft_integer_working_area->type = kTfLiteInt32; |
81 | // If fft_length is not a constant tensor, fft_integer_working_area will be |
82 | // set to dynamic later in Prepare. |
83 | fft_integer_working_area->allocation_type = kTfLiteArenaRw; |
84 | |
85 | // Set up FFT double working area buffer. |
86 | TfLiteTensor* fft_double_working_area; |
87 | TF_LITE_ENSURE_OK(context, |
88 | GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor, |
89 | &fft_double_working_area)); |
90 | // fft_double_working_area is a double tensor. Ideally, double should be |
91 | // added into tflite data types. However, since fft_double_working_area is a |
92 | // temporary tensor, and there are no ops having double input/output tensors |
93 | // in tflite at this point, adding double as a tflite data type may confuse |
94 | // users that double is supported. As a results, kTfLiteInt64 is used here |
95 | // for memory allocation. And it will be cast into double in Eval when being |
96 | // used. |
97 | fft_double_working_area->type = kTfLiteInt64; |
98 | // If fft_length is not a constant tensor, fft_double_working_area will be |
99 | // set to dynamic later in Prepare. |
100 | fft_double_working_area->allocation_type = kTfLiteArenaRw; |
101 | |
102 | return kTfLiteOk; |
103 | } |
104 | |
105 | TfLiteStatus ResizeOutputandTemporaryTensors(TfLiteContext* context, |
106 | TfLiteNode* node) { |
107 | const TfLiteTensor* input; |
108 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
109 | const int num_dims = NumDimensions(input); |
110 | TF_LITE_ENSURE(context, num_dims >= 2); |
111 | const TfLiteTensor* fft_length; |
112 | TF_LITE_ENSURE_OK(context, |
113 | GetInputSafe(context, node, kFftLengthTensor, &fft_length)); |
114 | const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length); |
115 | // The lib, fft2d, can only handle fft_lengths of power of 2. |
116 | TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[0])); |
117 | TF_LITE_ENSURE(context, IsPowerOfTwo(fft_length_data[1])); |
118 | |
119 | int fft_height, fft_width; |
120 | fft_height = fft_length_data[0]; |
121 | fft_width = fft_length_data[1]; |
122 | int fft_working_length = std::max(fft_height, fft_width / 2); |
123 | int half_fft_working_length = fft_working_length / 2; |
124 | |
125 | // Resize output tensor. |
126 | TfLiteTensor* output; |
127 | TF_LITE_ENSURE_OK(context, |
128 | GetOutputSafe(context, node, kOutputTensor, &output)); |
129 | TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims); |
130 | output_shape->data[num_dims - 2] = fft_length_data[0]; |
131 | output_shape->data[num_dims - 1] = fft_length_data[1] / 2 + 1; |
132 | TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape)); |
133 | |
134 | // Resize temporary tensors, fft_integer_working_area. |
135 | TfLiteTensor* fft_integer_working_area; |
136 | TF_LITE_ENSURE_OK( |
137 | context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor, |
138 | &fft_integer_working_area)); |
139 | TfLiteIntArray* fft_integer_working_area_shape = TfLiteIntArrayCreate(1); |
140 | fft_integer_working_area_shape->data[0] = |
141 | 2 + static_cast<int>(sqrt(fft_working_length)); |
142 | TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_integer_working_area, |
143 | fft_integer_working_area_shape)); |
144 | |
145 | // Resize temporary tensors, fft_double_working_area. |
146 | TfLiteTensor* fft_double_working_area; |
147 | TF_LITE_ENSURE_OK(context, |
148 | GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor, |
149 | &fft_double_working_area)); |
150 | TfLiteIntArray* fft_double_working_area_shape = TfLiteIntArrayCreate(1); |
151 | fft_double_working_area_shape->data[0] = |
152 | half_fft_working_length + fft_width / 4; |
153 | TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, fft_double_working_area, |
154 | fft_double_working_area_shape)); |
155 | |
156 | return kTfLiteOk; |
157 | } |
158 | |
159 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
160 | auto* data = new OpData; |
161 | return data; |
162 | } |
163 | |
164 | void Free(TfLiteContext* context, void* buffer) { |
165 | delete reinterpret_cast<OpData*>(buffer); |
166 | } |
167 | |
168 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
169 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
170 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
171 | |
172 | // Check type and shape of the input tensor |
173 | const TfLiteTensor* input; |
174 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
175 | TF_LITE_ENSURE(context, NumDimensions(input) >= 2); |
176 | if (input->type != kTfLiteFloat32) { |
177 | TF_LITE_KERNEL_LOG(context, |
178 | "Type '%s' for input is not supported by rfft2d." , |
179 | TfLiteTypeGetName(input->type)); |
180 | return kTfLiteError; |
181 | } |
182 | |
183 | // Check type and shape of the fft_length tensor |
184 | const TfLiteTensor* fft_length; |
185 | TF_LITE_ENSURE_OK(context, |
186 | GetInputSafe(context, node, kFftLengthTensor, &fft_length)); |
187 | const RuntimeShape fft_length_shape = GetTensorShape(fft_length); |
188 | |
189 | TF_LITE_ENSURE_EQ(context, NumDimensions(fft_length), 1); |
190 | TF_LITE_ENSURE_EQ(context, fft_length_shape.Dims(0), 2); |
191 | if (fft_length->type != kTfLiteInt32) { |
192 | TF_LITE_KERNEL_LOG(context, |
193 | "Type '%s' for fft_length is not supported by rfft2d." , |
194 | TfLiteTypeGetName(fft_length->type)); |
195 | return kTfLiteError; |
196 | } |
197 | |
198 | // Setup temporary tensors for fft computation. |
199 | TF_LITE_ENSURE_STATUS(InitTemporaryTensors(context, node)); |
200 | |
201 | // Set output type |
202 | TfLiteTensor* output; |
203 | TF_LITE_ENSURE_OK(context, |
204 | GetOutputSafe(context, node, kOutputTensor, &output)); |
205 | output->type = kTfLiteComplex64; |
206 | |
207 | // Exit early if fft_length is a non-const tensor. Set output tensor and |
208 | // temporary tensors to dynamic, so that their tensor sizes can be determined |
209 | // in Eval. |
210 | if (!IsConstantTensor(fft_length)) { |
211 | TfLiteTensor* fft_integer_working_area; |
212 | TF_LITE_ENSURE_OK( |
213 | context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor, |
214 | &fft_integer_working_area)); |
215 | TfLiteTensor* fft_double_working_area; |
216 | TF_LITE_ENSURE_OK( |
217 | context, GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor, |
218 | &fft_double_working_area)); |
219 | SetTensorToDynamic(fft_integer_working_area); |
220 | SetTensorToDynamic(fft_double_working_area); |
221 | SetTensorToDynamic(output); |
222 | return kTfLiteOk; |
223 | } |
224 | |
225 | TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node)); |
226 | return kTfLiteOk; |
227 | } |
228 | |
229 | // Reorder the result so that it matches the pattern of tf.signal.rfft2d. |
230 | // In tf.signal.fft2d the frequency matrix of a 4x4 input is |
231 | // [[F(0, 0), F(0, 1/4), F(0, 2/4)], |
232 | // [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)], |
233 | // [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)], |
234 | // [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]] |
235 | // While in rdft2d, the frequency matrix of a 4x4 input is |
236 | // [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0], |
237 | // [ F(-1/4, 0), F(-1/4, -1/4), 0], |
238 | // [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0], |
239 | // [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]] |
240 | // Since real fft has the property that |
241 | // Real(u,v) = Real(-u, -v) |
242 | // Img(u,v) = - Img(-u, -v) |
243 | // Result of rdft2d can be reordered and match the pattern of tf.signal.rfft2d. |
244 | // For example, |
245 | // Real(-3/4, 0) = Real(1/4, 0) = Real(-1/4, 0) |
246 | // Img(-3/4, 0) = Img(1/4, 0) = -Img(-1/4, 0) |
247 | void Rfft2dReorder(int fft_height, int fft_width, double** fft_input_output) { |
248 | int fft_height_half; |
249 | ruy::profiler::ScopeLabel label("Rfft2dReorder" ); |
250 | double real, img; |
251 | |
252 | fft_height_half = fft_height >> 1; |
253 | // Use 4x4 input as an example, reorder the frequency matrix from |
254 | // [[(F(0, 0), F(0, -2/4)) F(0, -1/4), 0], |
255 | // [ F(-1/4, 0), F(-1/4, -1/4), 0], |
256 | // [(F(-2/4, 0),F(-2/4, -2/4)), F(-2/4, -1/4), 0], |
257 | // [ j*F(-3/4, -2/4), F(-3/4, -1/4), 0]] |
258 | // to |
259 | // [[F(0, 0), F(0, -1/4), F(0, -2/4)], |
260 | // [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)], |
261 | // [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)], |
262 | // [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]] |
263 | for (int i = fft_height_half + 1; i < fft_height; ++i) { |
264 | real = fft_input_output[i][0]; |
265 | img = fft_input_output[i][1]; |
266 | fft_input_output[i][fft_width] = img; |
267 | fft_input_output[i][fft_width + 1] = real; |
268 | fft_input_output[fft_height - i][fft_width] = img; |
269 | fft_input_output[fft_height - i][fft_width + 1] = -real; |
270 | fft_input_output[i][0] = fft_input_output[fft_height - i][0]; |
271 | fft_input_output[i][1] = -fft_input_output[fft_height - i][1]; |
272 | } |
273 | |
274 | double temp = fft_input_output[0][1]; |
275 | fft_input_output[0][fft_width + 1] = 0; |
276 | fft_input_output[0][1] = 0; |
277 | fft_input_output[fft_height_half][fft_width] = |
278 | fft_input_output[fft_height_half][1]; |
279 | fft_input_output[fft_height_half][fft_width + 1] = 0; |
280 | fft_input_output[fft_height_half][1] = 0; |
281 | fft_input_output[0][fft_width] = temp; |
282 | |
283 | // Reorder the frequency matrix from |
284 | // [[F(0, 0), F(0, -1/4), F(0, -2/4)], |
285 | // [F(-1/4, 0), F(-1/4, -1/4), F(-1/4, -2/4)], |
286 | // [F(-2/4, 0), F(-2/4, -1/4), F(-2/4, -2/4)], |
287 | // [F(-3/4, 0), F(-3/4, -1/4), F(-3/4, -2/4)]] |
288 | // to |
289 | // [[F(0, 0), F(0, 1/4), F(0, 2/4)], |
290 | // [F(1/4, 0), F(1/4, 1/4), F(1/4, 2/4)], |
291 | // [F(2/4, 0), F(2/4, 1/4), F(2/4, 2/4)], |
292 | // [F(3/4, 0), F(3/4, 1/4), F(3/4, 2/4)]] |
293 | for (int i = 0; i < fft_height; ++i) { |
294 | for (int j = 1; j < fft_width + 2; j += 2) { |
295 | fft_input_output[i][j] = -fft_input_output[i][j]; |
296 | } |
297 | } |
298 | } |
299 | |
300 | void Rfft2dImpl(int fft_height, int fft_width, double** fft_input_output, |
301 | int* fft_integer_working_area_data, |
302 | double* fft_double_working_area_data) { |
303 | ruy::profiler::ScopeLabel label("Rfft2dImpl" ); |
304 | |
305 | // Working data areas for the FFT routines. |
306 | double* fft_dynamic_working_area = nullptr; |
307 | const int kForwardFft = 1; |
308 | rdft2d(fft_height, fft_width, kForwardFft, fft_input_output, |
309 | fft_dynamic_working_area, fft_integer_working_area_data, |
310 | fft_double_working_area_data); |
311 | Rfft2dReorder(fft_height, fft_width, fft_input_output); |
312 | } |
313 | |
// Copies one [input_height, input_width] float slice into the FFT buffer,
// converting to double. The buffer has fft_height rows of fft_width + 2
// doubles. Input rows and columns beyond the FFT size are cropped. Every
// buffer slot not covered by the input, including the two spare slots at the
// end of each row, is set to zero.
void PrepareInputBuffer(const float* input_data, int input_height,
                        int input_width, int fft_height, int fft_width,
                        double** fft_input_output) {
  const int rows_to_copy = std::min(input_height, fft_height);
  const int cols_to_copy = std::min(input_width, fft_width);
  const int padded_width = fft_width + 2;

  // Rows that overlap the input: copy, then zero-pad the tail.
  for (int row = 0; row < rows_to_copy; ++row) {
    const float* src = input_data + row * input_width;
    double* dst = fft_input_output[row];
    std::copy(src, src + cols_to_copy, dst);
    std::fill(dst + cols_to_copy, dst + padded_width, 0.0);
  }

  // Rows past the end of the input are entirely zero.
  for (int row = rows_to_copy; row < fft_height; ++row) {
    double* dst = fft_input_output[row];
    std::fill(dst, dst + padded_width, 0.0);
  }
}
337 | |
// Packs the reordered FFT buffer into the complex64 output slice.
// Row i of fft_input_output holds fft_width / 2 + 1 interleaved
// (real, imag) double pairs. They are narrowed to float and written row-major.
void PrepareOutputBuffer(complex<float>* output_data, int fft_height,
                         int fft_width, double** fft_input_output) {
  const int bins_per_row = fft_width / 2 + 1;
  complex<float>* out = output_data;
  for (int row = 0; row < fft_height; ++row) {
    const double* pair = fft_input_output[row];
    for (int bin = 0; bin < bins_per_row; ++bin, pair += 2) {
      *out++ = complex<float>(static_cast<float>(pair[0]),
                              static_cast<float>(pair[1]));
    }
  }
}
348 | |
349 | TfLiteStatus Rfft2dHelper(TfLiteContext* context, TfLiteNode* node) { |
350 | const TfLiteTensor* input; |
351 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
352 | const float* input_data = GetTensorData<float>(input); |
353 | const TfLiteTensor* fft_length; |
354 | TF_LITE_ENSURE_OK(context, |
355 | GetInputSafe(context, node, kFftLengthTensor, &fft_length)); |
356 | const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length); |
357 | TfLiteTensor* output; |
358 | TF_LITE_ENSURE_OK(context, |
359 | GetOutputSafe(context, node, kOutputTensor, &output)); |
360 | complex<float>* output_data = GetTensorData<complex<float>>(output); |
361 | |
362 | int fft_height, fft_width; |
363 | fft_height = fft_length_data[0]; |
364 | fft_width = fft_length_data[1]; |
365 | |
366 | // FFT is processed for every slice on the inner most 2 dimensions. |
367 | // Count the number of slices in the input tensor. |
368 | const RuntimeShape input_shape = GetTensorShape(input); |
369 | const int input_dims_count = input_shape.DimensionsCount(); |
370 | const auto* input_dims_data = input_shape.DimsData(); |
371 | int num_slices = 1; |
372 | for (int i = 0; i < input_dims_count - 2; ++i) { |
373 | num_slices *= input_dims_data[i]; |
374 | } |
375 | |
376 | int input_height = input_dims_data[input_dims_count - 2]; |
377 | int input_width = input_dims_data[input_dims_count - 1]; |
378 | int input_slice_size = input_height * input_width; |
379 | int output_slice_size = fft_height * (fft_width / 2 + 1); |
380 | |
381 | // Create input/output buffer for FFT |
382 | double** fft_input_output = new double*[fft_height]; |
383 | for (int i = 0; i < fft_height; ++i) { |
384 | fft_input_output[i] = new double[fft_width + 2]; |
385 | } |
386 | |
387 | // Get buffer for integer working area. |
388 | TfLiteTensor* fft_integer_working_area; |
389 | TF_LITE_ENSURE_OK( |
390 | context, GetTemporarySafe(context, node, kFftIntegerWorkingAreaTensor, |
391 | &fft_integer_working_area)); |
392 | int* fft_integer_working_area_data = |
393 | GetTensorData<int>(fft_integer_working_area); |
394 | |
395 | // Get buffer for double working area. |
396 | TfLiteTensor* fft_double_working_area; |
397 | TF_LITE_ENSURE_OK(context, |
398 | GetTemporarySafe(context, node, kFftDoubleWorkingAreaTensor, |
399 | &fft_double_working_area)); |
400 | // Get double value out of the memory of fft_double_working_area_data. |
401 | double* fft_double_working_area_data = reinterpret_cast<double*>( |
402 | GetTensorData<int64_t>(fft_double_working_area)); |
403 | |
404 | // Process every slice in the input buffer |
405 | for (int i = 0; i < num_slices; ++i) { |
406 | PrepareInputBuffer(input_data, input_height, input_width, fft_height, |
407 | fft_width, fft_input_output); |
408 | memset(fft_integer_working_area_data, 0, fft_integer_working_area->bytes); |
409 | memset(fft_double_working_area_data, 0, fft_double_working_area->bytes); |
410 | Rfft2dImpl(fft_height, fft_width, fft_input_output, |
411 | fft_integer_working_area_data, fft_double_working_area_data); |
412 | PrepareOutputBuffer(output_data, fft_height, fft_width, fft_input_output); |
413 | input_data += input_slice_size; |
414 | output_data += output_slice_size; |
415 | } |
416 | |
417 | // Delete the input buffer |
418 | for (int i = 0; i < fft_height; ++i) { |
419 | delete[] fft_input_output[i]; |
420 | } |
421 | delete[] fft_input_output; |
422 | |
423 | return kTfLiteOk; |
424 | } |
425 | |
426 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
427 | const TfLiteTensor* input; |
428 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
429 | const TfLiteTensor* fft_length; |
430 | TF_LITE_ENSURE_OK(context, |
431 | GetInputSafe(context, node, kFftLengthTensor, &fft_length)); |
432 | const int32_t* fft_length_data = GetTensorData<int32_t>(fft_length); |
433 | TfLiteTensor* output; |
434 | TF_LITE_ENSURE_OK(context, |
435 | GetOutputSafe(context, node, kOutputTensor, &output)); |
436 | |
437 | if (output->type != kTfLiteComplex64) { |
438 | TF_LITE_KERNEL_LOG(context, |
439 | "Type '%s' for output is not supported by rfft2d." , |
440 | TfLiteTypeGetName(output->type)); |
441 | return kTfLiteError; |
442 | } |
443 | |
444 | // Resize the output tensor if the fft_length tensor is not constant. |
445 | // Otherwise, check if the output shape is correct. |
446 | if (!IsConstantTensor(fft_length)) { |
447 | TF_LITE_ENSURE_STATUS(ResizeOutputandTemporaryTensors(context, node)); |
448 | } else { |
449 | int num_dims_output = NumDimensions(output); |
450 | const RuntimeShape output_shape = GetTensorShape(output); |
451 | TF_LITE_ENSURE_EQ(context, num_dims_output, NumDimensions(input)); |
452 | TF_LITE_ENSURE(context, num_dims_output >= 2); |
453 | TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 2), |
454 | fft_length_data[0]); |
455 | TF_LITE_ENSURE_EQ(context, output_shape.Dims(num_dims_output - 1), |
456 | fft_length_data[1] / 2 + 1); |
457 | } |
458 | |
459 | return Rfft2dHelper(context, node); |
460 | } |
461 | |
462 | } // namespace rfft2d |
463 | |
464 | TfLiteRegistration* Register_RFFT2D() { |
465 | static TfLiteRegistration r = {rfft2d::Init, rfft2d::Free, rfft2d::Prepare, |
466 | rfft2d::Eval}; |
467 | return &r; |
468 | } |
469 | |
470 | } // namespace builtin |
471 | } // namespace ops |
472 | } // namespace tflite |
473 | |