1 | #pragma once |
2 | |
3 | #include <ATen/ATen.h> |
4 | |
5 | namespace torch { |
6 | namespace fft { |
7 | |
8 | /// Computes the 1 dimensional fast Fourier transform over a given dimension. |
9 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.fft. |
10 | /// |
11 | /// Example: |
12 | /// ``` |
13 | /// auto t = torch::randn(128, dtype=kComplexDouble); |
14 | /// torch::fft::fft(t); |
15 | /// ``` |
16 | inline Tensor fft( |
17 | const Tensor& self, |
18 | c10::optional<int64_t> n = c10::nullopt, |
19 | int64_t dim = -1, |
20 | c10::optional<c10::string_view> norm = c10::nullopt) { |
21 | return torch::fft_fft(self, n, dim, norm); |
22 | } |
23 | |
24 | /// Computes the 1 dimensional inverse Fourier transform over a given dimension. |
25 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.ifft. |
26 | /// |
27 | /// Example: |
28 | /// ``` |
29 | /// auto t = torch::randn(128, dtype=kComplexDouble); |
30 | /// torch::fft::ifft(t); |
31 | /// ``` |
32 | inline Tensor ifft( |
33 | const Tensor& self, |
34 | c10::optional<int64_t> n = c10::nullopt, |
35 | int64_t dim = -1, |
36 | c10::optional<c10::string_view> norm = c10::nullopt) { |
37 | return torch::fft_ifft(self, n, dim, norm); |
38 | } |
39 | |
40 | /// Computes the 2-dimensional fast Fourier transform over the given dimensions. |
41 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.fft2. |
42 | /// |
43 | /// Example: |
44 | /// ``` |
45 | /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); |
46 | /// torch::fft::fft2(t); |
47 | /// ``` |
48 | inline Tensor fft2( |
49 | const Tensor& self, |
50 | OptionalIntArrayRef s = c10::nullopt, |
51 | IntArrayRef dim = {-2, -1}, |
52 | c10::optional<c10::string_view> norm = c10::nullopt) { |
53 | return torch::fft_fft2(self, s, dim, norm); |
54 | } |
55 | |
56 | /// Computes the inverse of torch.fft.fft2 |
57 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.ifft2. |
58 | /// |
59 | /// Example: |
60 | /// ``` |
61 | /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); |
62 | /// torch::fft::ifft2(t); |
63 | /// ``` |
64 | inline Tensor ifft2( |
65 | const Tensor& self, |
66 | at::OptionalIntArrayRef s = c10::nullopt, |
67 | IntArrayRef dim = {-2, -1}, |
68 | c10::optional<c10::string_view> norm = c10::nullopt) { |
69 | return torch::fft_ifft2(self, s, dim, norm); |
70 | } |
71 | |
72 | /// Computes the N dimensional fast Fourier transform over given dimensions. |
73 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.fftn. |
74 | /// |
75 | /// Example: |
76 | /// ``` |
77 | /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); |
78 | /// torch::fft::fftn(t); |
79 | /// ``` |
80 | inline Tensor fftn( |
81 | const Tensor& self, |
82 | at::OptionalIntArrayRef s = c10::nullopt, |
83 | at::OptionalIntArrayRef dim = c10::nullopt, |
84 | c10::optional<c10::string_view> norm = c10::nullopt) { |
85 | return torch::fft_fftn(self, s, dim, norm); |
86 | } |
87 | |
88 | /// Computes the N dimensional fast Fourier transform over given dimensions. |
89 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.ifftn. |
90 | /// |
91 | /// Example: |
92 | /// ``` |
93 | /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); |
94 | /// torch::fft::ifftn(t); |
95 | /// ``` |
96 | inline Tensor ifftn( |
97 | const Tensor& self, |
98 | at::OptionalIntArrayRef s = c10::nullopt, |
99 | at::OptionalIntArrayRef dim = c10::nullopt, |
100 | c10::optional<c10::string_view> norm = c10::nullopt) { |
101 | return torch::fft_ifftn(self, s, dim, norm); |
102 | } |
103 | |
104 | /// Computes the 1 dimensional FFT of real input with onesided Hermitian output. |
105 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.rfft. |
106 | /// |
107 | /// Example: |
108 | /// ``` |
109 | /// auto t = torch::randn(128); |
110 | /// auto T = torch::fft::rfft(t); |
111 | /// assert(T.is_complex() && T.numel() == 128 / 2 + 1); |
112 | /// ``` |
113 | inline Tensor rfft( |
114 | const Tensor& self, |
115 | c10::optional<int64_t> n = c10::nullopt, |
116 | int64_t dim = -1, |
117 | c10::optional<c10::string_view> norm = c10::nullopt) { |
118 | return torch::fft_rfft(self, n, dim, norm); |
119 | } |
120 | |
121 | /// Computes the inverse of torch.fft.rfft |
122 | /// |
123 | /// The input is a onesided Hermitian Fourier domain signal, with real-valued |
124 | /// output. See https://pytorch.org/docs/master/fft.html#torch.fft.irfft |
125 | /// |
126 | /// Example: |
127 | /// ``` |
128 | /// auto T = torch::randn(128 / 2 + 1, torch::kComplexDouble); |
129 | /// auto t = torch::fft::irfft(t, /*n=*/128); |
130 | /// assert(t.is_floating_point() && T.numel() == 128); |
131 | /// ``` |
132 | inline Tensor irfft( |
133 | const Tensor& self, |
134 | c10::optional<int64_t> n = c10::nullopt, |
135 | int64_t dim = -1, |
136 | c10::optional<c10::string_view> norm = c10::nullopt) { |
137 | return torch::fft_irfft(self, n, dim, norm); |
138 | } |
139 | |
140 | /// Computes the 2-dimensional FFT of real input. Returns a onesided Hermitian |
141 | /// output. See https://pytorch.org/docs/master/fft.html#torch.fft.rfft2 |
142 | /// |
143 | /// Example: |
144 | /// ``` |
145 | /// auto t = torch::randn({128, 128}, dtype=kDouble); |
146 | /// torch::fft::rfft2(t); |
147 | /// ``` |
148 | inline Tensor rfft2( |
149 | const Tensor& self, |
150 | at::OptionalIntArrayRef s = c10::nullopt, |
151 | IntArrayRef dim = {-2, -1}, |
152 | c10::optional<c10::string_view> norm = c10::nullopt) { |
153 | return torch::fft_rfft2(self, s, dim, norm); |
154 | } |
155 | |
156 | /// Computes the inverse of torch.fft.rfft2. |
157 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.irfft2. |
158 | /// |
159 | /// Example: |
160 | /// ``` |
161 | /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); |
162 | /// torch::fft::irfft2(t); |
163 | /// ``` |
164 | inline Tensor irfft2( |
165 | const Tensor& self, |
166 | at::OptionalIntArrayRef s = c10::nullopt, |
167 | IntArrayRef dim = {-2, -1}, |
168 | c10::optional<c10::string_view> norm = c10::nullopt) { |
169 | return torch::fft_irfft2(self, s, dim, norm); |
170 | } |
171 | |
172 | /// Computes the N dimensional FFT of real input with onesided Hermitian output. |
173 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.rfftn |
174 | /// |
175 | /// Example: |
176 | /// ``` |
177 | /// auto t = torch::randn({128, 128}, dtype=kDouble); |
178 | /// torch::fft::rfftn(t); |
179 | /// ``` |
180 | inline Tensor rfftn( |
181 | const Tensor& self, |
182 | at::OptionalIntArrayRef s = c10::nullopt, |
183 | at::OptionalIntArrayRef dim = c10::nullopt, |
184 | c10::optional<c10::string_view> norm = c10::nullopt) { |
185 | return torch::fft_rfftn(self, s, dim, norm); |
186 | } |
187 | |
188 | /// Computes the inverse of torch.fft.rfftn. |
189 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.irfftn. |
190 | /// |
191 | /// Example: |
192 | /// ``` |
193 | /// auto t = torch::randn({128, 128}, dtype=kComplexDouble); |
194 | /// torch::fft::irfftn(t); |
195 | /// ``` |
196 | inline Tensor irfftn( |
197 | const Tensor& self, |
198 | at::OptionalIntArrayRef s = c10::nullopt, |
199 | at::OptionalIntArrayRef dim = c10::nullopt, |
200 | c10::optional<c10::string_view> norm = c10::nullopt) { |
201 | return torch::fft_irfftn(self, s, dim, norm); |
202 | } |
203 | |
204 | /// Computes the 1 dimensional FFT of a onesided Hermitian signal |
205 | /// |
206 | /// The input represents a Hermitian symmetric time domain signal. The returned |
207 | /// Fourier domain representation of such a signal is a real-valued. See |
208 | /// https://pytorch.org/docs/master/fft.html#torch.fft.hfft |
209 | /// |
210 | /// Example: |
211 | /// ``` |
212 | /// auto t = torch::randn(128 / 2 + 1, torch::kComplexDouble); |
213 | /// auto T = torch::fft::hfft(t, /*n=*/128); |
214 | /// assert(T.is_floating_point() && T.numel() == 128); |
215 | /// ``` |
216 | inline Tensor hfft( |
217 | const Tensor& self, |
218 | c10::optional<int64_t> n = c10::nullopt, |
219 | int64_t dim = -1, |
220 | c10::optional<c10::string_view> norm = c10::nullopt) { |
221 | return torch::fft_hfft(self, n, dim, norm); |
222 | } |
223 | |
224 | /// Computes the inverse FFT of a real-valued Fourier domain signal. |
225 | /// |
226 | /// The output is a onesided representation of the Hermitian symmetric time |
227 | /// domain signal. See https://pytorch.org/docs/master/fft.html#torch.fft.ihfft. |
228 | /// |
229 | /// Example: |
230 | /// ``` |
231 | /// auto T = torch::randn(128, torch::kDouble); |
232 | /// auto t = torch::fft::ihfft(T); |
233 | /// assert(t.is_complex() && T.numel() == 128 / 2 + 1); |
234 | /// ``` |
235 | inline Tensor ihfft( |
236 | const Tensor& self, |
237 | c10::optional<int64_t> n = c10::nullopt, |
238 | int64_t dim = -1, |
239 | c10::optional<c10::string_view> norm = c10::nullopt) { |
240 | return torch::fft_ihfft(self, n, dim, norm); |
241 | } |
242 | |
243 | /// Computes the 2-dimensional FFT of a Hermitian symmetric input signal. |
244 | /// |
245 | /// The input is a onesided representation of the Hermitian symmetric time |
246 | /// domain signal. See https://pytorch.org/docs/master/fft.html#torch.fft.hfft2. |
247 | /// |
248 | /// Example: |
249 | /// ``` |
250 | /// auto t = torch::randn({128, 65}, torch::kComplexDouble); |
251 | /// auto T = torch::fft::hfft2(t, /*s=*/{128, 128}); |
252 | /// assert(T.is_floating_point() && T.numel() == 128 * 128); |
253 | /// ``` |
254 | inline Tensor hfft2( |
255 | const Tensor& self, |
256 | at::OptionalIntArrayRef s = c10::nullopt, |
257 | IntArrayRef dim = {-2, -1}, |
258 | c10::optional<c10::string_view> norm = c10::nullopt) { |
259 | return torch::fft_hfft2(self, s, dim, norm); |
260 | } |
261 | |
262 | /// Computes the 2-dimensional IFFT of a real input signal. |
263 | /// |
264 | /// The output is a onesided representation of the Hermitian symmetric time |
265 | /// domain signal. See |
266 | /// https://pytorch.org/docs/master/fft.html#torch.fft.ihfft2. |
267 | /// |
268 | /// Example: |
269 | /// ``` |
270 | /// auto T = torch::randn({128, 128}, torch::kDouble); |
271 | /// auto t = torch::fft::hfft2(T); |
272 | /// assert(t.is_complex() && t.size(1) == 65); |
273 | /// ``` |
274 | inline Tensor ihfft2( |
275 | const Tensor& self, |
276 | at::OptionalIntArrayRef s = c10::nullopt, |
277 | IntArrayRef dim = {-2, -1}, |
278 | c10::optional<c10::string_view> norm = c10::nullopt) { |
279 | return torch::fft_ihfft2(self, s, dim, norm); |
280 | } |
281 | |
282 | /// Computes the N-dimensional FFT of a Hermitian symmetric input signal. |
283 | /// |
284 | /// The input is a onesided representation of the Hermitian symmetric time |
285 | /// domain signal. See https://pytorch.org/docs/master/fft.html#torch.fft.hfftn. |
286 | /// |
287 | /// Example: |
288 | /// ``` |
289 | /// auto t = torch::randn({128, 65}, torch::kComplexDouble); |
290 | /// auto T = torch::fft::hfftn(t, /*s=*/{128, 128}); |
291 | /// assert(T.is_floating_point() && T.numel() == 128 * 128); |
292 | /// ``` |
293 | inline Tensor hfftn( |
294 | const Tensor& self, |
295 | at::OptionalIntArrayRef s = c10::nullopt, |
296 | IntArrayRef dim = {-2, -1}, |
297 | c10::optional<c10::string_view> norm = c10::nullopt) { |
298 | return torch::fft_hfftn(self, s, dim, norm); |
299 | } |
300 | |
301 | /// Computes the N-dimensional IFFT of a real input signal. |
302 | /// |
303 | /// The output is a onesided representation of the Hermitian symmetric time |
304 | /// domain signal. See |
305 | /// https://pytorch.org/docs/master/fft.html#torch.fft.ihfftn. |
306 | /// |
307 | /// Example: |
308 | /// ``` |
309 | /// auto T = torch::randn({128, 128}, torch::kDouble); |
310 | /// auto t = torch::fft::hfft2(T); |
311 | /// assert(t.is_complex() && t.size(1) == 65); |
312 | /// ``` |
313 | inline Tensor ihfftn( |
314 | const Tensor& self, |
315 | at::OptionalIntArrayRef s = c10::nullopt, |
316 | IntArrayRef dim = {-2, -1}, |
317 | c10::optional<c10::string_view> norm = c10::nullopt) { |
318 | return torch::fft_ihfftn(self, s, dim, norm); |
319 | } |
320 | |
321 | /// Computes the discrete Fourier Transform sample frequencies for a signal of |
322 | /// size n. |
323 | /// |
324 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.fftfreq |
325 | /// |
326 | /// Example: |
327 | /// ``` |
328 | /// auto frequencies = torch::fft::fftfreq(128, torch::kDouble); |
329 | /// ``` |
330 | inline Tensor fftfreq(int64_t n, double d, const TensorOptions& options = {}) { |
331 | return torch::fft_fftfreq(n, d, options); |
332 | } |
333 | |
334 | inline Tensor fftfreq(int64_t n, const TensorOptions& options = {}) { |
335 | return torch::fft_fftfreq(n, /*d=*/1.0, options); |
336 | } |
337 | |
338 | /// Computes the sample frequencies for torch.fft.rfft with a signal of size n. |
339 | /// |
340 | /// Like torch.fft.rfft, only the positive frequencies are included. |
341 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.rfftfreq |
342 | /// |
343 | /// Example: |
344 | /// ``` |
345 | /// auto frequencies = torch::fft::rfftfreq(128, torch::kDouble); |
346 | /// ``` |
347 | inline Tensor rfftfreq(int64_t n, double d, const TensorOptions& options) { |
348 | return torch::fft_rfftfreq(n, d, options); |
349 | } |
350 | |
351 | inline Tensor rfftfreq(int64_t n, const TensorOptions& options) { |
352 | return torch::fft_rfftfreq(n, /*d=*/1.0, options); |
353 | } |
354 | |
355 | /// Reorders n-dimensional FFT output to have negative frequency terms first, by |
356 | /// a torch.roll operation. |
357 | /// |
358 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.fftshift |
359 | /// |
360 | /// Example: |
361 | /// ``` |
362 | /// auto x = torch::randn({127, 4}); |
363 | /// auto centred_fft = torch::fft::fftshift(torch::fft::fftn(x)); |
364 | /// ``` |
365 | inline Tensor fftshift( |
366 | const Tensor& x, |
367 | at::OptionalIntArrayRef dim = c10::nullopt) { |
368 | return torch::fft_fftshift(x, dim); |
369 | } |
370 | |
371 | /// Inverse of torch.fft.fftshift |
372 | /// |
373 | /// See https://pytorch.org/docs/master/fft.html#torch.fft.ifftshift |
374 | /// |
375 | /// Example: |
376 | /// ``` |
377 | /// auto x = torch::randn({127, 4}); |
378 | /// auto shift = torch::fft::fftshift(x) |
379 | /// auto unshift = torch::fft::ifftshift(shift); |
380 | /// assert(torch::allclose(x, unshift)); |
381 | /// ``` |
382 | inline Tensor ifftshift( |
383 | const Tensor& x, |
384 | at::OptionalIntArrayRef dim = c10::nullopt) { |
385 | return torch::fft_ifftshift(x, dim); |
386 | } |
387 | |
388 | } // namespace fft |
389 | } // namespace torch |
390 | |