1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <complex>
#include <cstdint>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
22
23namespace tflite {
24namespace ops {
25namespace builtin {
26namespace complex {
27
// Index of the single input/output tensor of each complex-extraction node.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
30
31TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
32 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
33 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
34
35 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
36
37 TF_LITE_ENSURE(context, input->type == kTfLiteComplex64 ||
38 input->type == kTfLiteComplex128);
39
40 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
41
42 if (input->type == kTfLiteComplex64) {
43 TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
44 } else {
45 TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat64);
46 }
47
48 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
49 return context->ResizeTensor(context, output, output_shape);
50}
51
52template <typename T, typename ExtractF>
53void ExtractData(const TfLiteTensor* input, ExtractF extract_func,
54 TfLiteTensor* output) {
55 const std::complex<T>* input_data = GetTensorData<std::complex<T>>(input);
56 T* output_data = GetTensorData<T>(output);
57 const int input_size = NumElements(input);
58 for (int i = 0; i < input_size; ++i) {
59 *output_data++ = extract_func(*input_data++);
60 }
61}
62
63TfLiteStatus EvalReal(TfLiteContext* context, TfLiteNode* node) {
64 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
65
66 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
67
68 switch (input->type) {
69 case kTfLiteComplex64: {
70 ExtractData<float>(
71 input,
72 static_cast<float (*)(const std::complex<float>&)>(std::real<float>),
73 output);
74 break;
75 }
76 case kTfLiteComplex128: {
77 ExtractData<double>(input,
78 static_cast<double (*)(const std::complex<double>&)>(
79 std::real<double>),
80 output);
81 break;
82 }
83 default: {
84 TF_LITE_KERNEL_LOG(context,
85 "Unsupported input type, Real op only supports "
86 "complex input, but got: ",
87 TfLiteTypeGetName(input->type));
88 return kTfLiteError;
89 }
90 }
91
92 return kTfLiteOk;
93}
94
95TfLiteStatus EvalImag(TfLiteContext* context, TfLiteNode* node) {
96 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
97
98 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
99
100 switch (input->type) {
101 case kTfLiteComplex64: {
102 ExtractData<float>(
103 input,
104 static_cast<float (*)(const std::complex<float>&)>(std::imag<float>),
105 output);
106 break;
107 }
108 case kTfLiteComplex128: {
109 ExtractData<double>(input,
110 static_cast<double (*)(const std::complex<double>&)>(
111 std::imag<double>),
112 output);
113 break;
114 }
115 default: {
116 TF_LITE_KERNEL_LOG(context,
117 "Unsupported input type, Imag op only supports "
118 "complex input, but got: ",
119 TfLiteTypeGetName(input->type));
120 return kTfLiteError;
121 }
122 }
123
124 return kTfLiteOk;
125}
126
127TfLiteStatus EvalAbs(TfLiteContext* context, TfLiteNode* node) {
128 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
129 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
130
131 switch (input->type) {
132 case kTfLiteComplex64: {
133 ExtractData<float>(
134 input,
135 static_cast<float (*)(const std::complex<float>&)>(std::abs<float>),
136 output);
137 break;
138 }
139 case kTfLiteComplex128: {
140 ExtractData<double>(input,
141 static_cast<double (*)(const std::complex<double>&)>(
142 std::abs<double>),
143 output);
144 break;
145 }
146 default: {
147 TF_LITE_KERNEL_LOG(context,
148 "Unsupported input type, ComplexAbs op only supports "
149 "complex input, but got: ",
150 TfLiteTypeGetName(input->type));
151 return kTfLiteError;
152 }
153 }
154
155 return kTfLiteOk;
156}
157
158} // namespace complex
159
160TfLiteRegistration* Register_REAL() {
161 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
162 complex::Prepare, complex::EvalReal};
163 return &r;
164}
165
166TfLiteRegistration* Register_IMAG() {
167 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
168 complex::Prepare, complex::EvalImag};
169 return &r;
170}
171
172TfLiteRegistration* Register_COMPLEX_ABS() {
173 static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
174 complex::Prepare, complex::EvalAbs};
175 return &r;
176}
177
178} // namespace builtin
179} // namespace ops
180} // namespace tflite
181