1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/numeric_op.h" |
18 | #include "tensorflow/core/framework/op.h" |
19 | #include "tensorflow/core/framework/shape_inference.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | using shape_inference::DimensionHandle; |
24 | using shape_inference::InferenceContext; |
25 | using shape_inference::ShapeHandle; |
26 | |
27 | REGISTER_OP("FFT" ) |
28 | .Input("input: Tcomplex" ) |
29 | .Output("output: Tcomplex" ) |
30 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
31 | .SetShapeFn([](InferenceContext* c) { |
32 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); |
33 | }); |
34 | |
35 | REGISTER_OP("IFFT" ) |
36 | .Input("input: Tcomplex" ) |
37 | .Output("output: Tcomplex" ) |
38 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
39 | .SetShapeFn([](InferenceContext* c) { |
40 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); |
41 | }); |
42 | |
43 | REGISTER_OP("FFT2D" ) |
44 | .Input("input: Tcomplex" ) |
45 | .Output("output: Tcomplex" ) |
46 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
47 | .SetShapeFn([](InferenceContext* c) { |
48 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 2); |
49 | }); |
50 | |
51 | REGISTER_OP("IFFT2D" ) |
52 | .Input("input: Tcomplex" ) |
53 | .Output("output: Tcomplex" ) |
54 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
55 | .SetShapeFn([](InferenceContext* c) { |
56 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 2); |
57 | }); |
58 | |
59 | REGISTER_OP("FFT3D" ) |
60 | .Input("input: Tcomplex" ) |
61 | .Output("output: Tcomplex" ) |
62 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
63 | .SetShapeFn([](InferenceContext* c) { |
64 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); |
65 | }); |
66 | |
67 | REGISTER_OP("IFFT3D" ) |
68 | .Input("input: Tcomplex" ) |
69 | .Output("output: Tcomplex" ) |
70 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
71 | .SetShapeFn([](InferenceContext* c) { |
72 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 3); |
73 | }); |
74 | |
75 | Status RFFTShape(InferenceContext* c, const bool forward, const int rank) { |
76 | ShapeHandle out; |
77 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out)); |
78 | |
79 | // Check that fft_length has shape [rank]. |
80 | ShapeHandle unused_shape; |
81 | DimensionHandle unused_dim; |
82 | ShapeHandle fft_length_input = c->input(1); |
83 | TF_RETURN_IF_ERROR(c->WithRank(fft_length_input, 1, &unused_shape)); |
84 | TF_RETURN_IF_ERROR( |
85 | c->WithValue(c->Dim(fft_length_input, 0), rank, &unused_dim)); |
86 | const Tensor* fft_length_tensor = c->input_tensor(1); |
87 | |
88 | // If fft_length is unknown at graph creation time, we can't predict the |
89 | // output size. |
90 | if (fft_length_tensor == nullptr) { |
91 | // We can't know the dimension of any of the rank inner dimensions of the |
92 | // output without knowing fft_length. |
93 | for (int i = 0; i < rank; ++i) { |
94 | TF_RETURN_IF_ERROR(c->ReplaceDim(out, -rank + i, c->UnknownDim(), &out)); |
95 | } |
96 | } else { |
97 | auto fft_length_as_vec = fft_length_tensor->vec<int32>(); |
98 | for (int i = 0; i < rank; ++i) { |
99 | // For RFFT, replace the last dimension with fft_length/2 + 1. |
100 | auto dim = forward && i == rank - 1 && fft_length_as_vec(i) != 0 |
101 | ? fft_length_as_vec(i) / 2 + 1 |
102 | : fft_length_as_vec(i); |
103 | TF_RETURN_IF_ERROR(c->ReplaceDim(out, -rank + i, c->MakeDim(dim), &out)); |
104 | } |
105 | } |
106 | |
107 | c->set_output(0, out); |
108 | return OkStatus(); |
109 | } |
110 | |
111 | REGISTER_OP("RFFT" ) |
112 | .Input("input: Treal" ) |
113 | .Input("fft_length: int32" ) |
114 | .Output("output: Tcomplex" ) |
115 | .Attr("Treal: {float32, float64} = DT_FLOAT" ) |
116 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
117 | .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 1); }); |
118 | |
119 | REGISTER_OP("IRFFT" ) |
120 | .Input("input: Tcomplex" ) |
121 | .Input("fft_length: int32" ) |
122 | .Output("output: Treal" ) |
123 | .Attr("Treal: {float32, float64} = DT_FLOAT" ) |
124 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
125 | .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 1); }); |
126 | |
127 | REGISTER_OP("RFFT2D" ) |
128 | .Input("input: Treal" ) |
129 | .Input("fft_length: int32" ) |
130 | .Output("output: Tcomplex" ) |
131 | .Attr("Treal: {float32, float64} = DT_FLOAT" ) |
132 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
133 | .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 2); }); |
134 | |
135 | REGISTER_OP("IRFFT2D" ) |
136 | .Input("input: Tcomplex" ) |
137 | .Input("fft_length: int32" ) |
138 | .Output("output: Treal" ) |
139 | .Attr("Treal: {float32, float64} = DT_FLOAT" ) |
140 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
141 | .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 2); }); |
142 | |
143 | REGISTER_OP("RFFT3D" ) |
144 | .Input("input: Treal" ) |
145 | .Input("fft_length: int32" ) |
146 | .Output("output: Tcomplex" ) |
147 | .Attr("Treal: {float32, float64} = DT_FLOAT" ) |
148 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
149 | .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, true, 3); }); |
150 | |
151 | REGISTER_OP("IRFFT3D" ) |
152 | .Input("input: Tcomplex" ) |
153 | .Input("fft_length: int32" ) |
154 | .Output("output: Treal" ) |
155 | .Attr("Treal: {float32, float64} = DT_FLOAT" ) |
156 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
157 | .SetShapeFn([](InferenceContext* c) { return RFFTShape(c, false, 3); }); |
158 | |
159 | // Deprecated ops: |
160 | REGISTER_OP("BatchFFT" ) |
161 | .Input("input: complex64" ) |
162 | .Output("output: complex64" ) |
163 | .SetShapeFn(shape_inference::UnknownShape) |
164 | .Deprecated(15, "Use FFT" ); |
165 | REGISTER_OP("BatchIFFT" ) |
166 | .Input("input: complex64" ) |
167 | .Output("output: complex64" ) |
168 | .SetShapeFn(shape_inference::UnknownShape) |
169 | .Deprecated(15, "Use IFFT" ); |
170 | REGISTER_OP("BatchFFT2D" ) |
171 | .Input("input: complex64" ) |
172 | .Output("output: complex64" ) |
173 | .SetShapeFn(shape_inference::UnknownShape) |
174 | .Deprecated(15, "Use FFT2D" ); |
175 | REGISTER_OP("BatchIFFT2D" ) |
176 | .Input("input: complex64" ) |
177 | .Output("output: complex64" ) |
178 | .SetShapeFn(shape_inference::UnknownShape) |
179 | .Deprecated(15, "Use IFFT2D" ); |
180 | REGISTER_OP("BatchFFT3D" ) |
181 | .Input("input: complex64" ) |
182 | .Output("output: complex64" ) |
183 | .SetShapeFn(shape_inference::UnknownShape) |
184 | .Deprecated(15, "Use FFT3D" ); |
185 | REGISTER_OP("BatchIFFT3D" ) |
186 | .Input("input: complex64" ) |
187 | .Output("output: complex64" ) |
188 | .SetShapeFn(shape_inference::UnknownShape) |
189 | .Deprecated(15, "Use IFFT3D" ); |
190 | |
191 | } // namespace tensorflow |
192 | |