1#pragma once
2
3#include <ATen/ATen.h>
4
5namespace torch {
6namespace fft {
7
8/// Computes the 1 dimensional fast Fourier transform over a given dimension.
9/// See https://pytorch.org/docs/master/fft.html#torch.fft.fft.
10///
11/// Example:
12/// ```
13/// auto t = torch::randn(128, dtype=kComplexDouble);
14/// torch::fft::fft(t);
15/// ```
16inline Tensor fft(
17 const Tensor& self,
18 c10::optional<int64_t> n = c10::nullopt,
19 int64_t dim = -1,
20 c10::optional<c10::string_view> norm = c10::nullopt) {
21 return torch::fft_fft(self, n, dim, norm);
22}
23
24/// Computes the 1 dimensional inverse Fourier transform over a given dimension.
25/// See https://pytorch.org/docs/master/fft.html#torch.fft.ifft.
26///
27/// Example:
28/// ```
29/// auto t = torch::randn(128, dtype=kComplexDouble);
30/// torch::fft::ifft(t);
31/// ```
32inline Tensor ifft(
33 const Tensor& self,
34 c10::optional<int64_t> n = c10::nullopt,
35 int64_t dim = -1,
36 c10::optional<c10::string_view> norm = c10::nullopt) {
37 return torch::fft_ifft(self, n, dim, norm);
38}
39
40/// Computes the 2-dimensional fast Fourier transform over the given dimensions.
41/// See https://pytorch.org/docs/master/fft.html#torch.fft.fft2.
42///
43/// Example:
44/// ```
45/// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
46/// torch::fft::fft2(t);
47/// ```
48inline Tensor fft2(
49 const Tensor& self,
50 OptionalIntArrayRef s = c10::nullopt,
51 IntArrayRef dim = {-2, -1},
52 c10::optional<c10::string_view> norm = c10::nullopt) {
53 return torch::fft_fft2(self, s, dim, norm);
54}
55
56/// Computes the inverse of torch.fft.fft2
57/// See https://pytorch.org/docs/master/fft.html#torch.fft.ifft2.
58///
59/// Example:
60/// ```
61/// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
62/// torch::fft::ifft2(t);
63/// ```
64inline Tensor ifft2(
65 const Tensor& self,
66 at::OptionalIntArrayRef s = c10::nullopt,
67 IntArrayRef dim = {-2, -1},
68 c10::optional<c10::string_view> norm = c10::nullopt) {
69 return torch::fft_ifft2(self, s, dim, norm);
70}
71
72/// Computes the N dimensional fast Fourier transform over given dimensions.
73/// See https://pytorch.org/docs/master/fft.html#torch.fft.fftn.
74///
75/// Example:
76/// ```
77/// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
78/// torch::fft::fftn(t);
79/// ```
80inline Tensor fftn(
81 const Tensor& self,
82 at::OptionalIntArrayRef s = c10::nullopt,
83 at::OptionalIntArrayRef dim = c10::nullopt,
84 c10::optional<c10::string_view> norm = c10::nullopt) {
85 return torch::fft_fftn(self, s, dim, norm);
86}
87
88/// Computes the N dimensional fast Fourier transform over given dimensions.
89/// See https://pytorch.org/docs/master/fft.html#torch.fft.ifftn.
90///
91/// Example:
92/// ```
93/// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
94/// torch::fft::ifftn(t);
95/// ```
96inline Tensor ifftn(
97 const Tensor& self,
98 at::OptionalIntArrayRef s = c10::nullopt,
99 at::OptionalIntArrayRef dim = c10::nullopt,
100 c10::optional<c10::string_view> norm = c10::nullopt) {
101 return torch::fft_ifftn(self, s, dim, norm);
102}
103
104/// Computes the 1 dimensional FFT of real input with onesided Hermitian output.
105/// See https://pytorch.org/docs/master/fft.html#torch.fft.rfft.
106///
107/// Example:
108/// ```
109/// auto t = torch::randn(128);
110/// auto T = torch::fft::rfft(t);
111/// assert(T.is_complex() && T.numel() == 128 / 2 + 1);
112/// ```
113inline Tensor rfft(
114 const Tensor& self,
115 c10::optional<int64_t> n = c10::nullopt,
116 int64_t dim = -1,
117 c10::optional<c10::string_view> norm = c10::nullopt) {
118 return torch::fft_rfft(self, n, dim, norm);
119}
120
121/// Computes the inverse of torch.fft.rfft
122///
123/// The input is a onesided Hermitian Fourier domain signal, with real-valued
124/// output. See https://pytorch.org/docs/master/fft.html#torch.fft.irfft
125///
126/// Example:
127/// ```
128/// auto T = torch::randn(128 / 2 + 1, torch::kComplexDouble);
129/// auto t = torch::fft::irfft(t, /*n=*/128);
130/// assert(t.is_floating_point() && T.numel() == 128);
131/// ```
132inline Tensor irfft(
133 const Tensor& self,
134 c10::optional<int64_t> n = c10::nullopt,
135 int64_t dim = -1,
136 c10::optional<c10::string_view> norm = c10::nullopt) {
137 return torch::fft_irfft(self, n, dim, norm);
138}
139
140/// Computes the 2-dimensional FFT of real input. Returns a onesided Hermitian
141/// output. See https://pytorch.org/docs/master/fft.html#torch.fft.rfft2
142///
143/// Example:
144/// ```
145/// auto t = torch::randn({128, 128}, dtype=kDouble);
146/// torch::fft::rfft2(t);
147/// ```
148inline Tensor rfft2(
149 const Tensor& self,
150 at::OptionalIntArrayRef s = c10::nullopt,
151 IntArrayRef dim = {-2, -1},
152 c10::optional<c10::string_view> norm = c10::nullopt) {
153 return torch::fft_rfft2(self, s, dim, norm);
154}
155
156/// Computes the inverse of torch.fft.rfft2.
157/// See https://pytorch.org/docs/master/fft.html#torch.fft.irfft2.
158///
159/// Example:
160/// ```
161/// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
162/// torch::fft::irfft2(t);
163/// ```
164inline Tensor irfft2(
165 const Tensor& self,
166 at::OptionalIntArrayRef s = c10::nullopt,
167 IntArrayRef dim = {-2, -1},
168 c10::optional<c10::string_view> norm = c10::nullopt) {
169 return torch::fft_irfft2(self, s, dim, norm);
170}
171
172/// Computes the N dimensional FFT of real input with onesided Hermitian output.
173/// See https://pytorch.org/docs/master/fft.html#torch.fft.rfftn
174///
175/// Example:
176/// ```
177/// auto t = torch::randn({128, 128}, dtype=kDouble);
178/// torch::fft::rfftn(t);
179/// ```
180inline Tensor rfftn(
181 const Tensor& self,
182 at::OptionalIntArrayRef s = c10::nullopt,
183 at::OptionalIntArrayRef dim = c10::nullopt,
184 c10::optional<c10::string_view> norm = c10::nullopt) {
185 return torch::fft_rfftn(self, s, dim, norm);
186}
187
188/// Computes the inverse of torch.fft.rfftn.
189/// See https://pytorch.org/docs/master/fft.html#torch.fft.irfftn.
190///
191/// Example:
192/// ```
193/// auto t = torch::randn({128, 128}, dtype=kComplexDouble);
194/// torch::fft::irfftn(t);
195/// ```
196inline Tensor irfftn(
197 const Tensor& self,
198 at::OptionalIntArrayRef s = c10::nullopt,
199 at::OptionalIntArrayRef dim = c10::nullopt,
200 c10::optional<c10::string_view> norm = c10::nullopt) {
201 return torch::fft_irfftn(self, s, dim, norm);
202}
203
204/// Computes the 1 dimensional FFT of a onesided Hermitian signal
205///
206/// The input represents a Hermitian symmetric time domain signal. The returned
207/// Fourier domain representation of such a signal is a real-valued. See
208/// https://pytorch.org/docs/master/fft.html#torch.fft.hfft
209///
210/// Example:
211/// ```
212/// auto t = torch::randn(128 / 2 + 1, torch::kComplexDouble);
213/// auto T = torch::fft::hfft(t, /*n=*/128);
214/// assert(T.is_floating_point() && T.numel() == 128);
215/// ```
216inline Tensor hfft(
217 const Tensor& self,
218 c10::optional<int64_t> n = c10::nullopt,
219 int64_t dim = -1,
220 c10::optional<c10::string_view> norm = c10::nullopt) {
221 return torch::fft_hfft(self, n, dim, norm);
222}
223
224/// Computes the inverse FFT of a real-valued Fourier domain signal.
225///
226/// The output is a onesided representation of the Hermitian symmetric time
227/// domain signal. See https://pytorch.org/docs/master/fft.html#torch.fft.ihfft.
228///
229/// Example:
230/// ```
231/// auto T = torch::randn(128, torch::kDouble);
232/// auto t = torch::fft::ihfft(T);
233/// assert(t.is_complex() && T.numel() == 128 / 2 + 1);
234/// ```
235inline Tensor ihfft(
236 const Tensor& self,
237 c10::optional<int64_t> n = c10::nullopt,
238 int64_t dim = -1,
239 c10::optional<c10::string_view> norm = c10::nullopt) {
240 return torch::fft_ihfft(self, n, dim, norm);
241}
242
243/// Computes the 2-dimensional FFT of a Hermitian symmetric input signal.
244///
245/// The input is a onesided representation of the Hermitian symmetric time
246/// domain signal. See https://pytorch.org/docs/master/fft.html#torch.fft.hfft2.
247///
248/// Example:
249/// ```
250/// auto t = torch::randn({128, 65}, torch::kComplexDouble);
251/// auto T = torch::fft::hfft2(t, /*s=*/{128, 128});
252/// assert(T.is_floating_point() && T.numel() == 128 * 128);
253/// ```
254inline Tensor hfft2(
255 const Tensor& self,
256 at::OptionalIntArrayRef s = c10::nullopt,
257 IntArrayRef dim = {-2, -1},
258 c10::optional<c10::string_view> norm = c10::nullopt) {
259 return torch::fft_hfft2(self, s, dim, norm);
260}
261
262/// Computes the 2-dimensional IFFT of a real input signal.
263///
264/// The output is a onesided representation of the Hermitian symmetric time
265/// domain signal. See
266/// https://pytorch.org/docs/master/fft.html#torch.fft.ihfft2.
267///
268/// Example:
269/// ```
270/// auto T = torch::randn({128, 128}, torch::kDouble);
271/// auto t = torch::fft::hfft2(T);
272/// assert(t.is_complex() && t.size(1) == 65);
273/// ```
274inline Tensor ihfft2(
275 const Tensor& self,
276 at::OptionalIntArrayRef s = c10::nullopt,
277 IntArrayRef dim = {-2, -1},
278 c10::optional<c10::string_view> norm = c10::nullopt) {
279 return torch::fft_ihfft2(self, s, dim, norm);
280}
281
282/// Computes the N-dimensional FFT of a Hermitian symmetric input signal.
283///
284/// The input is a onesided representation of the Hermitian symmetric time
285/// domain signal. See https://pytorch.org/docs/master/fft.html#torch.fft.hfftn.
286///
287/// Example:
288/// ```
289/// auto t = torch::randn({128, 65}, torch::kComplexDouble);
290/// auto T = torch::fft::hfftn(t, /*s=*/{128, 128});
291/// assert(T.is_floating_point() && T.numel() == 128 * 128);
292/// ```
293inline Tensor hfftn(
294 const Tensor& self,
295 at::OptionalIntArrayRef s = c10::nullopt,
296 IntArrayRef dim = {-2, -1},
297 c10::optional<c10::string_view> norm = c10::nullopt) {
298 return torch::fft_hfftn(self, s, dim, norm);
299}
300
301/// Computes the N-dimensional IFFT of a real input signal.
302///
303/// The output is a onesided representation of the Hermitian symmetric time
304/// domain signal. See
305/// https://pytorch.org/docs/master/fft.html#torch.fft.ihfftn.
306///
307/// Example:
308/// ```
309/// auto T = torch::randn({128, 128}, torch::kDouble);
310/// auto t = torch::fft::hfft2(T);
311/// assert(t.is_complex() && t.size(1) == 65);
312/// ```
313inline Tensor ihfftn(
314 const Tensor& self,
315 at::OptionalIntArrayRef s = c10::nullopt,
316 IntArrayRef dim = {-2, -1},
317 c10::optional<c10::string_view> norm = c10::nullopt) {
318 return torch::fft_ihfftn(self, s, dim, norm);
319}
320
321/// Computes the discrete Fourier Transform sample frequencies for a signal of
322/// size n.
323///
324/// See https://pytorch.org/docs/master/fft.html#torch.fft.fftfreq
325///
326/// Example:
327/// ```
328/// auto frequencies = torch::fft::fftfreq(128, torch::kDouble);
329/// ```
330inline Tensor fftfreq(int64_t n, double d, const TensorOptions& options = {}) {
331 return torch::fft_fftfreq(n, d, options);
332}
333
334inline Tensor fftfreq(int64_t n, const TensorOptions& options = {}) {
335 return torch::fft_fftfreq(n, /*d=*/1.0, options);
336}
337
338/// Computes the sample frequencies for torch.fft.rfft with a signal of size n.
339///
340/// Like torch.fft.rfft, only the positive frequencies are included.
341/// See https://pytorch.org/docs/master/fft.html#torch.fft.rfftfreq
342///
343/// Example:
344/// ```
345/// auto frequencies = torch::fft::rfftfreq(128, torch::kDouble);
346/// ```
347inline Tensor rfftfreq(int64_t n, double d, const TensorOptions& options) {
348 return torch::fft_rfftfreq(n, d, options);
349}
350
351inline Tensor rfftfreq(int64_t n, const TensorOptions& options) {
352 return torch::fft_rfftfreq(n, /*d=*/1.0, options);
353}
354
355/// Reorders n-dimensional FFT output to have negative frequency terms first, by
356/// a torch.roll operation.
357///
358/// See https://pytorch.org/docs/master/fft.html#torch.fft.fftshift
359///
360/// Example:
361/// ```
362/// auto x = torch::randn({127, 4});
363/// auto centred_fft = torch::fft::fftshift(torch::fft::fftn(x));
364/// ```
365inline Tensor fftshift(
366 const Tensor& x,
367 at::OptionalIntArrayRef dim = c10::nullopt) {
368 return torch::fft_fftshift(x, dim);
369}
370
371/// Inverse of torch.fft.fftshift
372///
373/// See https://pytorch.org/docs/master/fft.html#torch.fft.ifftshift
374///
375/// Example:
376/// ```
377/// auto x = torch::randn({127, 4});
378/// auto shift = torch::fft::fftshift(x)
379/// auto unshift = torch::fft::ifftshift(shift);
380/// assert(torch::allclose(x, unshift));
381/// ```
382inline Tensor ifftshift(
383 const Tensor& x,
384 at::OptionalIntArrayRef dim = c10::nullopt) {
385 return torch::fft_ifftshift(x, dim);
386}
387
388} // namespace fft
389} // namespace torch
390