#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fft_c2c_native.h>
#include <ATen/ops/_fft_c2r_native.h>
#include <ATen/ops/_fft_r2c_native.h>
#include <ATen/ops/empty.h>
#endif

#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>

namespace at { namespace native {
// In a real-to-complex transform, MKL FFT only fills half of the values due to
// conjugate symmetry. See native/SpectralOpsUtils.h for more details.
// The following function is used to fill in the other half with symmetry in
// case of a real-to-complex transform with the onesided=False flag.
// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h.
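// For example (illustrative): the full DFT of a real signal of length 4 is
// [X0, X1, X2, conj(X1)]. The onesided transform stores only [X0, X1, X2], so the
// remaining bin is reconstructed here as the conjugate of its mirrored counterpart.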

template <typename scalar_t>
static __ubsan_ignore_undefined__  // UBSAN gives false positives on using negative indexes with a pointer
void _fft_fill_with_conjugate_symmetry_slice(
    Range range, at::ArrayRef<bool> is_mirrored_dim, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides, const scalar_t * in_ptr,
    IntArrayRef out_strides, scalar_t * out_ptr) {
  const auto ndim = signal_half_sizes.size();
  DimVector iter_index(ndim, 0);

  // We explicitly loop over one row, then use this lambda to iterate over
  // n-dimensions. This advances iter_index by one row, while updating in_ptr
  // and out_ptr to point to the new row of data.
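  // e.g. (illustrative) for signal_half_sizes == {4, 3, 2}: once the explicit loop
  // has consumed a row of 4 elements, advance_index() bumps dimension 1 (wrapping
  // into dimension 2 when dimension 1 reaches 3) and moves in_ptr/out_ptr to the
  // start of the next row, odometer-style.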
  auto advance_index = [&] () __ubsan_ignore_undefined__ {
    for (const auto i : c10::irange(1, iter_index.size())) {
      if (iter_index[i] + 1 < signal_half_sizes[i]) {
        ++iter_index[i];
        in_ptr += in_strides[i];
        if (is_mirrored_dim[i]) {
          if (iter_index[i] == 1) {
            out_ptr += (signal_half_sizes[i] - 1) * out_strides[i];
          } else {
            out_ptr -= out_strides[i];
          }
        } else {
          out_ptr += out_strides[i];
        }
        return;
      }

      in_ptr -= in_strides[i] * iter_index[i];
      if (is_mirrored_dim[i]) {
        out_ptr -= out_strides[i];
      } else {
        out_ptr -= out_strides[i] * iter_index[i];
      }
      iter_index[i] = 0;
    }
  };

  // The data slice we operate on may start part-way into the data
  // Update iter_index and pointers to reference the start of the slice
  if (range.begin > 0) {
    iter_index[0] = range.begin % signal_half_sizes[0];
    auto linear_idx = range.begin / signal_half_sizes[0];

    for (size_t i = 1; i < ndim && linear_idx > 0; ++i) {
      iter_index[i] = linear_idx % signal_half_sizes[i];
      linear_idx = linear_idx / signal_half_sizes[i];

      if (iter_index[i] > 0) {
        in_ptr += in_strides[i] * iter_index[i];
        if (is_mirrored_dim[i]) {
          out_ptr += out_strides[i] * (signal_half_sizes[i] - iter_index[i]);
        } else {
          out_ptr += out_strides[i] * iter_index[i];
        }
      }
    }
  }

  auto numel_remaining = range.end - range.begin;

  if (is_mirrored_dim[0]) {
    // Explicitly loop over a Hermitian mirrored dimension
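    // Index i of the onesided input maps to index (signal_half_sizes[0] - i) of the
    // output; element 0 is its own mirror and is written directly as conj(in_ptr[0])
    // in the full-row loop below.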
    if (iter_index[0] > 0) {
      auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
      for (const auto i : c10::irange(iter_index[0], end)) {
        out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= (end - iter_index[0]);
      iter_index[0] = 0;
      advance_index();
    }

    while (numel_remaining > 0) {
      auto end = std::min(signal_half_sizes[0], numel_remaining);
      out_ptr[0] = std::conj(in_ptr[0]);
      for (const auto i : c10::irange(1, end)) {
        out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= end;
      advance_index();
    }
  } else {
    // Explicit loop over a non-mirrored dimension, so just a simple conjugated copy
    while (numel_remaining > 0) {
      auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
      for (int64_t i = iter_index[0]; i != end; ++i) {
        out_ptr[i * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= (end - iter_index[0]);
      iter_index[0] = 0;
      advance_index();
    }
  }
}

static void _fft_fill_with_conjugate_symmetry_cpu_(
    ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides_bytes, const void * in_data,
    IntArrayRef out_strides_bytes, void * out_data) {

  // Convert strides from bytes to elements
  const auto element_size = scalarTypeToTypeMeta(dtype).itemsize();
  const auto ndim = signal_half_sizes.size();
  DimVector in_strides(ndim), out_strides(ndim);
  for (const auto i : c10::irange(ndim)) {
    TORCH_INTERNAL_ASSERT(in_strides_bytes[i] % element_size == 0);
    in_strides[i] = in_strides_bytes[i] / element_size;
    TORCH_INTERNAL_ASSERT(out_strides_bytes[i] % element_size == 0);
    out_strides[i] = out_strides_bytes[i] / element_size;
  }

  // Construct boolean mask for mirrored dims
  c10::SmallVector<bool, at::kDimVectorStaticSize> is_mirrored_dim(ndim, false);
  for (const auto& dim : mirror_dims) {
    is_mirrored_dim[dim] = true;
  }

  const auto numel = c10::multiply_integers(signal_half_sizes);
  AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
    at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
        [&](int64_t begin, int64_t end) {
          _fft_fill_with_conjugate_symmetry_slice(
              {begin, end}, is_mirrored_dim, signal_half_sizes,
              in_strides, static_cast<const scalar_t*>(in_data),
              out_strides, static_cast<scalar_t*>(out_data));
        });
  });
}

// Register this one implementation for all cpu types instead of compiling multiple times
REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)

// _out variants can be shared between PocketFFT and MKL
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true);
  if (onesided) {
    resize_output(out, result.sizes());
    return out.copy_(result);
  }

  resize_output(out, self.sizes());

  auto last_dim = dim.back();
  auto last_dim_halfsize = result.sizes()[last_dim];
  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
  out_slice.copy_(result);
  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  return out;
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  auto result = _fft_c2c_mkl(self, dim, normalization, forward);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

}} // namespace at::native
#endif /* AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED() */

#if AT_POCKETFFT_ENABLED()
#include <pocketfft_hdronly.h>

namespace at { namespace native {

namespace {
using namespace pocketfft;

stride_t stride_from_tensor(const Tensor& t) {
  stride_t stride(t.strides().begin(), t.strides().end());
  for (auto& s : stride) {
    s *= t.element_size();
  }
  return stride;
}

inline shape_t shape_from_tensor(const Tensor& t) {
  return shape_t(t.sizes().begin(), t.sizes().end());
}

template<typename T>
inline std::complex<T> *tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(t.data_ptr<c10::complex<T>>());
}

template<typename T>
inline const std::complex<T> *tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(t.data_ptr<c10::complex<T>>());
}

template<typename T>
T compute_fct(int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none: return one;
    case fft_norm_mode::by_n: return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast<T>(size));
  }
  AT_ERROR("Unsupported normalization type: ", normalization);
}

template<typename T>
T compute_fct(const Tensor& t, IntArrayRef dim, int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for (auto idx : dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(n, normalization);
}
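// For example (illustrative): a transform over two dims of sizes 4 and 8 has n == 32,
// so fft_norm_mode::by_n yields a scale factor of 1/32 and fft_norm_mode::by_root_n
// yields 1/sqrt(32).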

} // anonymous namespace

Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  auto in_sizes = self.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<float>(self),
                   out.data_ptr<float>(), compute_fct<float>(out, dim, normalization));
  } else {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<double>(self),
                   out.data_ptr<double>(), compute_fct<double>(out, dim, normalization));
  }
  return out;
}


Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
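  // e.g. a last dim of length 10 yields 10 / 2 + 1 == 6 unique complex bins (indices 0..5)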
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }

  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kFloat) {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.data_ptr<float>(),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.data_ptr<double>(),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  if (!onesided) {
    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  }
  return out;
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  auto out = at::empty(self.sizes(), self.options());
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<float>(self),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<double>(self),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  return out;
}

}} // namespace at::native

#elif AT_MKL_ENABLED()
#include <ATen/Dispatch.h>

#include <algorithm>
#include <numeric>
#include <cmath>

#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>


namespace at { namespace native {

// Constructs an MKL-FFT plan descriptor representing the desired transform.
// For complex types, strides are in units of 2 * element_size(dtype).
// Sizes are for the full signal, including the batch dimension, and are always two-sided.
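// For example (illustrative): a contiguous complex64 batch of shape {B, N} would be
// described with sizes {B, N} and strides {N, 1}, i.e. strides counted in complex
// elements rather than bytes.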
static DftiDescriptor _plan_mkl_fft(
    IntArrayRef in_strides, IntArrayRef out_strides, IntArrayRef sizes,
    bool complex_input, bool complex_output,
    int64_t normalization, bool forward, ScalarType dtype) {
  const int64_t signal_ndim = sizes.size() - 1;
  TORCH_INTERNAL_ASSERT(in_strides.size() == sizes.size());
  TORCH_INTERNAL_ASSERT(out_strides.size() == sizes.size());

  // precision
  const DFTI_CONFIG_VALUE prec = [&]{
    switch (c10::toRealValueType(dtype)) {
      case ScalarType::Float: return DFTI_SINGLE;
      case ScalarType::Double: return DFTI_DOUBLE;
      default: TORCH_CHECK(false, "MKL FFT doesn't support tensors of type: ", dtype);
    }
  }();
  // signal type
  const DFTI_CONFIG_VALUE signal_type = [&]{
    if (forward) {
      return complex_input ? DFTI_COMPLEX : DFTI_REAL;
    } else {
      return complex_output ? DFTI_COMPLEX : DFTI_REAL;
    }
  }();
  // create descriptor with signal size
  using MklDimVector = c10::SmallVector<MKL_LONG, at::kDimVectorStaticSize>;
  MklDimVector mkl_signal_sizes(sizes.begin() + 1, sizes.end());
  DftiDescriptor descriptor;
  descriptor.init(prec, signal_type, signal_ndim, mkl_signal_sizes.data());
  // out of place FFT
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
  // batch mode
  MKL_LONG mkl_batch_size = sizes[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, mkl_batch_size));

  // batch dim stride, i.e. the distance between consecutive batch elements
  TORCH_CHECK(in_strides[0] <= MKL_LONG_MAX && out_strides[0] <= MKL_LONG_MAX);
  MKL_LONG idist = in_strides[0];
  MKL_LONG odist = out_strides[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));

  // signal strides
  // the first value is the offset, which is set to zero (ignored)
  MklDimVector mkl_istrides(1 + signal_ndim, 0), mkl_ostrides(1 + signal_ndim, 0);
  for (int64_t i = 1; i <= signal_ndim; i++) {
    TORCH_CHECK(in_strides[i] <= MKL_LONG_MAX && out_strides[i] <= MKL_LONG_MAX);
    mkl_istrides[i] = in_strides[i];
    mkl_ostrides[i] = out_strides[i];
  }
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_istrides.data()));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_ostrides.data()));
  // If the conjugate-even domain of a real transform is involved, set the standard
  // CCE storage type (this will become the default in MKL in the future).
  if (!complex_input || !complex_output) {
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
  }
  // rescale if requested
  const auto norm = static_cast<fft_norm_mode>(normalization);
  int64_t signal_numel = c10::multiply_integers(IntArrayRef(sizes.data() + 1, signal_ndim));
  if (norm != fft_norm_mode::none) {
    const double scale = (
      (norm == fft_norm_mode::by_root_n) ?
      1.0 / std::sqrt(static_cast<double>(signal_numel)) :
      1.0 / static_cast<double>(signal_numel));
    const auto scale_direction = forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE;
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale));
  }

  if (sizeof(MKL_LONG) < sizeof(int64_t)) {
    TORCH_CHECK(signal_numel <= MKL_LONG_MAX,
                "MKL FFT: input signal numel exceeds allowed range [1, ", MKL_LONG_MAX, "]");
  }

  // finalize
  MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get()));

  return descriptor;
}

// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
static Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes,
                         IntArrayRef dim, int64_t normalization, bool forward) {
  const auto ndim = self.dim();
  const int64_t signal_ndim = dim.size();
  const auto batch_dims = ndim - signal_ndim;

  // Permute dimensions so batch dimensions come first, and in stride order
  // This maximizes data locality when collapsing to a single batch dimension
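  // e.g. (illustrative) for a tensor of shape {2, 3, 8} transformed over dim == {2},
  // where dim 1 happens to have a larger stride than dim 0, dim_permute becomes {1, 0, 2}.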
  DimVector dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});

  c10::SmallVector<bool, kDimVectorStaticSize> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
                                  [&](int64_t d) { return !is_transformed_dim[d]; });
  auto self_strides = self.strides();
  std::sort(dim_permute.begin(), batch_end,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  std::copy(dim.cbegin(), dim.cend(), batch_end);
  auto input = self.permute(dim_permute);

  // Collapse batch dimensions into a single dimension
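  // e.g. (continuing the illustration above) an input permuted to {3, 2, 8} with one
  // transformed dim of length 8 is reshaped to {-1, 8}, i.e. {6, 8}.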
  DimVector batched_sizes(signal_ndim + 1);
  batched_sizes[0] = -1;
  std::copy(input.sizes().cbegin() + batch_dims, input.sizes().cend(), batched_sizes.begin() + 1);
  input = input.reshape(batched_sizes);

  const auto batch_size = input.sizes()[0];
  DimVector signal_size(signal_ndim + 1);
  signal_size[0] = batch_size;
  for (const auto i : c10::irange(signal_ndim)) {
    auto in_size = input.sizes()[i + 1];
    auto out_size = out_sizes[dim[i]];
    signal_size[i + 1] = std::max(in_size, out_size);
    TORCH_INTERNAL_ASSERT(in_size == signal_size[i + 1] ||
                          in_size == (signal_size[i + 1] / 2) + 1);
    TORCH_INTERNAL_ASSERT(out_size == signal_size[i + 1] ||
                          out_size == (signal_size[i + 1] / 2) + 1);
  }

  batched_sizes[0] = batch_size;
  DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end());
  for (const auto i : c10::irange(dim.size())) {
    batched_out_sizes[i + 1] = out_sizes[dim[i]];
  }

  const auto value_type = c10::toRealValueType(input.scalar_type());
  out.resize_(batched_out_sizes, MemoryFormat::Contiguous);

  auto descriptor = _plan_mkl_fft(
      input.strides(), out.strides(), signal_size, input.is_complex(),
      out.is_complex(), normalization, forward, value_type);

  // run the FFT
  if (forward) {
    MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), input.data_ptr(), out.data_ptr()));
  } else {
    MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), input.data_ptr(), out.data_ptr()));
  }

  // Reshape in place to the original batch shape and invert the dimension permutation
  DimVector out_strides(ndim);
  int64_t batch_numel = 1;
  for (int64_t i = batch_dims - 1; i >= 0; --i) {
    out_strides[dim_permute[i]] = batch_numel * out.strides()[0];
    batch_numel *= out_sizes[dim_permute[i]];
  }
  for (const auto i : c10::irange(batch_dims, ndim)) {
    out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)];
  }
  out.as_strided_(out_sizes, out_strides, out.storage_offset());
  return out;
}

// Sort transform dimensions by input layout, for best performance
// exclude_last is for onesided transforms where the last dimension cannot be reordered
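// e.g. (illustrative) with self.strides() == {1, 16, 4} and dim == {0, 1, 2} the result
// is {1, 2, 0}; with exclude_last=true the last dim stays put and the result is {1, 0, 2}.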
static DimVector _sort_dims(const Tensor& self, IntArrayRef dim, bool exclude_last=false) {
  DimVector sorted_dims(dim.begin(), dim.end());
  auto self_strides = self.strides();
  std::sort(sorted_dims.begin(), sorted_dims.end() - exclude_last,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  return sorted_dims;
}

// n-dimensional complex to real IFFT
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  TORCH_CHECK(self.is_complex());
  // NOTE: Multi-dimensional C2R transforms don't agree with numpy in cases
  // where the input isn't strictly Hermitian-symmetric. Instead, we use a
  // multi-dim C2C transform followed by a 1D C2R transform.
  //
  // Such inputs are technically out of contract though, so maybe a disagreement
  // is okay.
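  // e.g. for dim == {0, 1, 2}: a C2C transform is run over {0, 1} below, followed by a
  // 1D C2R transform over {2}.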
  auto input = self;
  if (dim.size() > 1) {
    auto c2c_dims = dim.slice(0, dim.size() - 1);
    input = _fft_c2c_mkl(self, c2c_dims, normalization, /*forward=*/false);
    dim = dim.slice(dim.size() - 1);
  }

  auto in_sizes = input.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  return _exec_fft(out, input, out_sizes, dim, normalization, /*forward=*/false);
}

// n-dimensional real to complex FFT
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }

  auto sorted_dims = _sort_dims(self, dim, /*exclude_last=*/true);
  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  _exec_fft(out, self, out_sizes, sorted_dims, normalization, /*forward=*/true);

  if (!onesided) {
    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  }
  return out;
}

// n-dimensional complex to complex FFT/IFFT
Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  const auto sorted_dims = _sort_dims(self, dim);
  auto out = at::empty(self.sizes(), self.options());
  return _exec_fft(out, self, self.sizes(), sorted_dims, normalization, forward);
}

}} // namespace at::native

#else

namespace at { namespace native {
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub);

Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

}} // namespace at::native
#endif
