1 | #include <gtest/gtest.h> |
2 | |
3 | #include <c10/util/irange.h> |
4 | #include <test/cpp/api/support.h> |
5 | #include <torch/torch.h> |
6 | |
7 | // Naive DFT of a 1 dimensional tensor |
8 | torch::Tensor naive_dft(torch::Tensor x, bool forward = true) { |
9 | TORCH_INTERNAL_ASSERT(x.dim() == 1); |
10 | x = x.contiguous(); |
11 | auto out_tensor = torch::zeros_like(x); |
12 | const int64_t len = x.size(0); |
13 | |
14 | // Roots of unity, exp(-2*pi*j*n/N) for n in [0, N), reversed for inverse |
15 | // transform |
16 | std::vector<c10::complex<double>> roots(len); |
17 | const auto angle_base = (forward ? -2.0 : 2.0) * M_PI / len; |
18 | for (const auto i : c10::irange(len)) { |
19 | auto angle = i * angle_base; |
20 | roots[i] = c10::complex<double>(std::cos(angle), std::sin(angle)); |
21 | } |
22 | |
23 | const auto in = x.data_ptr<c10::complex<double>>(); |
24 | const auto out = out_tensor.data_ptr<c10::complex<double>>(); |
25 | for (const auto i : c10::irange(len)) { |
26 | for (const auto j : c10::irange(len)) { |
27 | out[i] += roots[(j * i) % len] * in[j]; |
28 | } |
29 | } |
30 | return out_tensor; |
31 | } |
32 | |
// NOTE: As of August 2020, Visual Studio and ROCm builds did not support
// complex literals, so the tests below avoid them.
35 | |
36 | TEST(FFTTest, fft) { |
37 | auto t = torch::randn(128, torch::kComplexDouble); |
38 | auto actual = torch::fft::fft(t); |
39 | auto expect = naive_dft(t); |
40 | ASSERT_TRUE(torch::allclose(actual, expect)); |
41 | } |
42 | |
43 | TEST(FFTTest, fft_real) { |
44 | auto t = torch::randn(128, torch::kDouble); |
45 | auto actual = torch::fft::fft(t); |
46 | auto expect = torch::fft::fft(t.to(torch::kComplexDouble)); |
47 | ASSERT_TRUE(torch::allclose(actual, expect)); |
48 | } |
49 | |
50 | TEST(FFTTest, fft_pad) { |
51 | auto t = torch::randn(128, torch::kComplexDouble); |
52 | auto actual = torch::fft::fft(t, 200); |
53 | auto expect = torch::fft::fft(torch::constant_pad_nd(t, {0, 72})); |
54 | ASSERT_TRUE(torch::allclose(actual, expect)); |
55 | |
56 | actual = torch::fft::fft(t, 64); |
57 | expect = torch::fft::fft(torch::constant_pad_nd(t, {0, -64})); |
58 | ASSERT_TRUE(torch::allclose(actual, expect)); |
59 | } |
60 | |
61 | TEST(FFTTest, fft_norm) { |
62 | auto t = torch::randn(128, torch::kComplexDouble); |
63 | // NOLINTNEXTLINE(bugprone-argument-comment) |
64 | auto unnorm = torch::fft::fft(t, /*n=*/{}, /*axis=*/-1, /*norm=*/{}); |
65 | // NOLINTNEXTLINE(bugprone-argument-comment) |
66 | auto norm = torch::fft::fft(t, /*n=*/{}, /*axis=*/-1, /*norm=*/"forward" ); |
67 | ASSERT_TRUE(torch::allclose(unnorm / 128, norm)); |
68 | |
69 | // NOLINTNEXTLINE(bugprone-argument-comment) |
70 | auto ortho_norm = torch::fft::fft(t, /*n=*/{}, /*axis=*/-1, /*norm=*/"ortho" ); |
71 | ASSERT_TRUE(torch::allclose(unnorm / std::sqrt(128), ortho_norm)); |
72 | } |
73 | |
74 | TEST(FFTTest, ifft) { |
75 | auto T = torch::randn(128, torch::kComplexDouble); |
76 | auto actual = torch::fft::ifft(T); |
77 | auto expect = naive_dft(T, /*forward=*/false) / 128; |
78 | ASSERT_TRUE(torch::allclose(actual, expect)); |
79 | } |
80 | |
81 | TEST(FFTTest, fft_ifft) { |
82 | auto t = torch::randn(77, torch::kComplexDouble); |
83 | auto T = torch::fft::fft(t); |
84 | ASSERT_EQ(T.size(0), 77); |
85 | ASSERT_EQ(T.scalar_type(), torch::kComplexDouble); |
86 | |
87 | auto t_round_trip = torch::fft::ifft(T); |
88 | ASSERT_EQ(t_round_trip.size(0), 77); |
89 | ASSERT_EQ(t_round_trip.scalar_type(), torch::kComplexDouble); |
90 | ASSERT_TRUE(torch::allclose(t, t_round_trip)); |
91 | } |
92 | |
93 | TEST(FFTTest, rfft) { |
94 | auto t = torch::randn(129, torch::kDouble); |
95 | auto actual = torch::fft::rfft(t); |
96 | auto expect = torch::fft::fft(t.to(torch::kComplexDouble)).slice(0, 0, 65); |
97 | ASSERT_TRUE(torch::allclose(actual, expect)); |
98 | } |
99 | |
100 | TEST(FFTTest, rfft_irfft) { |
101 | auto t = torch::randn(128, torch::kDouble); |
102 | auto T = torch::fft::rfft(t); |
103 | ASSERT_EQ(T.size(0), 65); |
104 | ASSERT_EQ(T.scalar_type(), torch::kComplexDouble); |
105 | |
106 | auto t_round_trip = torch::fft::irfft(T); |
107 | ASSERT_EQ(t_round_trip.size(0), 128); |
108 | ASSERT_EQ(t_round_trip.scalar_type(), torch::kDouble); |
109 | ASSERT_TRUE(torch::allclose(t, t_round_trip)); |
110 | } |
111 | |
112 | TEST(FFTTest, ihfft) { |
113 | auto T = torch::randn(129, torch::kDouble); |
114 | auto actual = torch::fft::ihfft(T); |
115 | auto expect = torch::fft::ifft(T.to(torch::kComplexDouble)).slice(0, 0, 65); |
116 | ASSERT_TRUE(torch::allclose(actual, expect)); |
117 | } |
118 | |
119 | TEST(FFTTest, hfft_ihfft) { |
120 | auto t = torch::randn(64, torch::kComplexDouble); |
121 | t[0] = .5; // Must be purely real to satisfy hermitian symmetry |
122 | auto T = torch::fft::hfft(t, 127); |
123 | ASSERT_EQ(T.size(0), 127); |
124 | ASSERT_EQ(T.scalar_type(), torch::kDouble); |
125 | |
126 | auto t_round_trip = torch::fft::ihfft(T); |
127 | ASSERT_EQ(t_round_trip.size(0), 64); |
128 | ASSERT_EQ(t_round_trip.scalar_type(), torch::kComplexDouble); |
129 | ASSERT_TRUE(torch::allclose(t, t_round_trip)); |
130 | } |
131 | |