#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SpectralOpsUtils.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fft_c2c_native.h>
#include <ATen/ops/_fft_c2r_native.h>
#include <ATen/ops/_fft_r2c_native.h>
#include <ATen/ops/empty.h>
#endif

#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>

namespace at { namespace native {
// In real-to-complex transform, MKL FFT only fills half of the values due to
// conjugate symmetry. See native/SpectralOpsUtils.h for more details.
// The following functions are used to fill in the other half with conjugate
// symmetry in case of real-to-complex transform with onesided=False flag.
// See NOTE [ Fourier Transform Conjugate Symmetry ] in native/SpectralOpsUtils.h.
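//
// Briefly (illustrative 1-D case): for a real input of length n, the full
// spectrum satisfies X[k] == conj(X[n - k]) for k = 1, ..., n - 1, so only the
// first n/2 + 1 bins need to be computed and the remaining bins can be
// reconstructed by copying conjugated values; the code below handles n-d.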

template <typename scalar_t>
static __ubsan_ignore_undefined__ // UBSAN gives false positives on using negative indexes with a pointer
void _fft_fill_with_conjugate_symmetry_slice(
    Range range, at::ArrayRef<bool> is_mirrored_dim, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides, const scalar_t * in_ptr,
    IntArrayRef out_strides, scalar_t * out_ptr) {
  const auto ndim = signal_half_sizes.size();
  DimVector iter_index(ndim, 0);

  // We explicitly loop over one row, then use this lambda to iterate over
  // n-dimensions. This advances iter_index by one row, while updating in_ptr
  // and out_ptr to point to the new row of data.
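  //
  // For instance (illustrative), with signal_half_sizes == {n0, 3, 2} the
  // lambda steps iter_index like an odometer over dims 1..ndim-1:
  //   {*, 0, 0} -> {*, 1, 0} -> {*, 2, 0} -> {*, 0, 1} -> ...
  // Dim 0 is walked by the explicit loops further below. For a mirrored dim,
  // the output pointer visits indices 0, size-1, size-2, ..., 1 so that
  // out[size - i] receives conj(in[i]).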
  auto advance_index = [&] () __ubsan_ignore_undefined__ {
    for (const auto i : c10::irange(1, iter_index.size())) {
      if (iter_index[i] + 1 < signal_half_sizes[i]) {
        ++iter_index[i];
        in_ptr += in_strides[i];
        if (is_mirrored_dim[i]) {
          if (iter_index[i] == 1) {
            out_ptr += (signal_half_sizes[i] - 1) * out_strides[i];
          } else {
            out_ptr -= out_strides[i];
          }
        } else {
          out_ptr += out_strides[i];
        }
        return;
      }

      in_ptr -= in_strides[i] * iter_index[i];
      if (is_mirrored_dim[i]) {
        out_ptr -= out_strides[i];
      } else {
        out_ptr -= out_strides[i] * iter_index[i];
      }
      iter_index[i] = 0;
    }
  };

  // The data slice we operate on may start part-way into the data
  // Update iter_index and pointers to reference the start of the slice
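  // e.g. (illustrative) range.begin == 10 with signal_half_sizes == {4, 3, ...}
  // decomposes into iter_index == {2, 2, 0, ...} below.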
  if (range.begin > 0) {
    iter_index[0] = range.begin % signal_half_sizes[0];
    auto linear_idx = range.begin / signal_half_sizes[0];

    for (size_t i = 1; i < ndim && linear_idx > 0; ++i) {
      iter_index[i] = linear_idx % signal_half_sizes[i];
      linear_idx = linear_idx / signal_half_sizes[i];

      if (iter_index[i] > 0) {
        in_ptr += in_strides[i] * iter_index[i];
        if (is_mirrored_dim[i]) {
          out_ptr += out_strides[i] * (signal_half_sizes[i] - iter_index[i]);
        } else {
          out_ptr += out_strides[i] * iter_index[i];
        }
      }
    }
  }

  auto numel_remaining = range.end - range.begin;

  if (is_mirrored_dim[0]) {
    // Explicitly loop over a Hermitian mirrored dimension
    if (iter_index[0] > 0) {
      auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
      for (const auto i : c10::irange(iter_index[0], end)) {
        out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= (end - iter_index[0]);
      iter_index[0] = 0;
      advance_index();
    }

    while (numel_remaining > 0) {
      auto end = std::min(signal_half_sizes[0], numel_remaining);
      out_ptr[0] = std::conj(in_ptr[0]);
      for (const auto i : c10::irange(1, end)) {
        out_ptr[(signal_half_sizes[0] - i) * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= end;
      advance_index();
    }
  } else {
    // Explicit loop over a non-mirrored dimension, so just a simple conjugated copy
    while (numel_remaining > 0) {
      auto end = std::min(signal_half_sizes[0], iter_index[0] + numel_remaining);
      for (int64_t i = iter_index[0]; i != end; ++i) {
        out_ptr[i * out_strides[0]] = std::conj(in_ptr[i * in_strides[0]]);
      }
      numel_remaining -= (end - iter_index[0]);
      iter_index[0] = 0;
      advance_index();
    }
  }
}

static void _fft_fill_with_conjugate_symmetry_cpu_(
    ScalarType dtype, IntArrayRef mirror_dims, IntArrayRef signal_half_sizes,
    IntArrayRef in_strides_bytes, const void * in_data,
    IntArrayRef out_strides_bytes, void * out_data) {

  // Convert strides from bytes to elements
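  // e.g. (illustrative) for kComplexFloat (itemsize 8 bytes) a byte stride of
  // 16 becomes an element stride of 2.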
  const auto element_size = scalarTypeToTypeMeta(dtype).itemsize();
  const auto ndim = signal_half_sizes.size();
  DimVector in_strides(ndim), out_strides(ndim);
  for (const auto i : c10::irange(ndim)) {
    TORCH_INTERNAL_ASSERT(in_strides_bytes[i] % element_size == 0);
    in_strides[i] = in_strides_bytes[i] / element_size;
    TORCH_INTERNAL_ASSERT(out_strides_bytes[i] % element_size == 0);
    out_strides[i] = out_strides_bytes[i] / element_size;
  }

  // Construct boolean mask for mirrored dims
  c10::SmallVector<bool, at::kDimVectorStaticSize> is_mirrored_dim(ndim, false);
  for (const auto& dim : mirror_dims) {
    is_mirrored_dim[dim] = true;
  }

  const auto numel = c10::multiply_integers(signal_half_sizes);
  AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
    at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
        [&](int64_t begin, int64_t end) {
          _fft_fill_with_conjugate_symmetry_slice(
              {begin, end}, is_mirrored_dim, signal_half_sizes,
              in_strides, static_cast<const scalar_t*>(in_data),
              out_strides, static_cast<scalar_t*>(out_data));
        });
  });
}

// Register this one implementation for all cpu types instead of compiling multiple times
REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_ZVECTOR_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)
REGISTER_VSX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_)

// _out variants can be shared between PocketFFT and MKL
Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  auto result = _fft_r2c_mkl(self, dim, normalization, /*onesided=*/true);
  if (onesided) {
    resize_output(out, result.sizes());
    return out.copy_(result);
  }

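  // Two-sided output: copy the onesided half into the first bins of the last
  // transformed dim, then reconstruct the remaining bins via conjugate symmetry.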
  resize_output(out, self.sizes());

  auto last_dim = dim.back();
  auto last_dim_halfsize = result.sizes()[last_dim];
  auto out_slice = out.slice(last_dim, 0, last_dim_halfsize);
  out_slice.copy_(result);
  at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  return out;
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  auto result = _fft_c2r_mkl(self, dim, normalization, last_dim_size);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  auto result = _fft_c2c_mkl(self, dim, normalization, forward);
  resize_output(out, result.sizes());
  return out.copy_(result);
}

}} // namespace at::native
#endif /* AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED() */

#if AT_POCKETFFT_ENABLED()
#include <pocketfft_hdronly.h>

namespace at { namespace native {

namespace {
using namespace pocketfft;

stride_t stride_from_tensor(const Tensor& t) {
  stride_t stride(t.strides().begin(), t.strides().end());
  for(auto& s: stride) {
    s *= t.element_size();
  }
  return stride;
}

inline shape_t shape_from_tensor(const Tensor& t) {
  return shape_t(t.sizes().begin(), t.sizes().end());
}

template<typename T>
inline std::complex<T> *tensor_cdata(Tensor& t) {
  return reinterpret_cast<std::complex<T>*>(t.data_ptr<c10::complex<T>>());
}

template<typename T>
inline const std::complex<T> *tensor_cdata(const Tensor& t) {
  return reinterpret_cast<const std::complex<T>*>(t.data_ptr<c10::complex<T>>());
}

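// normalization is an fft_norm_mode value cast to int64_t; the resulting scale
// factor is 1 for none, 1/n for by_n and 1/sqrt(n) for by_root_n, where n is
// the product of the transformed signal lengths (see compute_fct below).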
template<typename T>
T compute_fct(int64_t size, int64_t normalization) {
  constexpr auto one = static_cast<T>(1);
  switch (static_cast<fft_norm_mode>(normalization)) {
    case fft_norm_mode::none: return one;
    case fft_norm_mode::by_n: return one / static_cast<T>(size);
    case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast<T>(size));
  }
  AT_ERROR("Unsupported normalization type", normalization);
}

template<typename T>
T compute_fct(const Tensor& t, IntArrayRef dim, int64_t normalization) {
  if (static_cast<fft_norm_mode>(normalization) == fft_norm_mode::none) {
    return static_cast<T>(1);
  }
  const auto& sizes = t.sizes();
  int64_t n = 1;
  for(auto idx: dim) {
    n *= sizes[idx];
  }
  return compute_fct<T>(n, normalization);
}

} // anonymous namespace

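// last_dim_size is the full (two-sided) length of the last transformed dim. It
// must be passed in because it cannot be recovered from the onesided input:
// both n == 2*(h - 1) and n == 2*h - 1 produce a halved size of h.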
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  auto in_sizes = self.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<float>(self),
                   out.data_ptr<float>(), compute_fct<float>(out, dim, normalization));
  } else {
    pocketfft::c2r(shape_from_tensor(out), stride_from_tensor(self), stride_from_tensor(out), axes, false,
                   tensor_cdata<double>(self),
                   out.data_ptr<double>(), compute_fct<double>(out, dim, normalization));
  }
  return out;
}


Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }
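  // Onesided r2c keeps only the first n/2 + 1 frequency bins of the last
  // transformed dim, e.g. a length-8 signal yields 5 complex outputs.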

  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kFloat) {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.data_ptr<float>(),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::r2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, true,
                   self.data_ptr<double>(),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  if (!onesided) {
    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  }
  return out;
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  auto out = at::empty(self.sizes(), self.options());
  pocketfft::shape_t axes(dim.begin(), dim.end());
  if (self.scalar_type() == kComplexFloat) {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<float>(self),
                   tensor_cdata<float>(out), compute_fct<float>(self, dim, normalization));
  } else {
    pocketfft::c2c(shape_from_tensor(self), stride_from_tensor(self), stride_from_tensor(out), axes, forward,
                   tensor_cdata<double>(self),
                   tensor_cdata<double>(out), compute_fct<double>(self, dim, normalization));
  }

  return out;
}

}} // namespace at::native

#elif AT_MKL_ENABLED()
#include <ATen/Dispatch.h>

#include <algorithm>
#include <numeric>
#include <cmath>

#include <mkl_dfti.h>
#include <ATen/mkl/Exceptions.h>
#include <ATen/mkl/Descriptors.h>
#include <ATen/mkl/Limits.h>


namespace at { namespace native {

// Constructs an mkl-fft plan descriptor representing the desired transform
// For complex types, strides are in units of 2 * element_size(dtype)
// sizes are for the full signal, including batch size and always two-sided
static DftiDescriptor _plan_mkl_fft(
    IntArrayRef in_strides, IntArrayRef out_strides, IntArrayRef sizes,
    bool complex_input, bool complex_output,
    int64_t normalization, bool forward, ScalarType dtype) {
  const int64_t signal_ndim = sizes.size() - 1;
  TORCH_INTERNAL_ASSERT(in_strides.size() == sizes.size());
  TORCH_INTERNAL_ASSERT(out_strides.size() == sizes.size());

  // precision
  const DFTI_CONFIG_VALUE prec = [&]{
    switch (c10::toRealValueType(dtype)) {
      case ScalarType::Float: return DFTI_SINGLE;
      case ScalarType::Double: return DFTI_DOUBLE;
      default: TORCH_CHECK(false, "MKL FFT doesn't support tensors of type: ", dtype);
    }
  }();
  // signal type
  const DFTI_CONFIG_VALUE signal_type = [&]{
    if (forward) {
      return complex_input ? DFTI_COMPLEX : DFTI_REAL;
    } else {
      return complex_output ? DFTI_COMPLEX : DFTI_REAL;
    }
  }();
  // create descriptor with signal size
  using MklDimVector = c10::SmallVector<MKL_LONG, at::kDimVectorStaticSize>;
  MklDimVector mkl_signal_sizes(sizes.begin() + 1, sizes.end());
  DftiDescriptor descriptor;
  descriptor.init(prec, signal_type, signal_ndim, mkl_signal_sizes.data());
  // out of place FFT
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_PLACEMENT, DFTI_NOT_INPLACE));
  // batch mode
  MKL_LONG mkl_batch_size = sizes[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_NUMBER_OF_TRANSFORMS, mkl_batch_size));

  // batch dim stride, i.e., dist between each data
  TORCH_CHECK(in_strides[0] <= MKL_LONG_MAX && out_strides[0] <= MKL_LONG_MAX);
  MKL_LONG idist = in_strides[0];
  MKL_LONG odist = out_strides[0];
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_DISTANCE, idist));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_DISTANCE, odist));

  // signal strides
  // first val is offset, set to zero (ignored)
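  // The stride arrays take the form {offset, s1, ..., s_d} in element units
  // (complex elements for complex data), e.g. {0, W, 1} for a contiguous
  // [batch, H, W] signal (illustrative).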
  MklDimVector mkl_istrides(1 + signal_ndim, 0), mkl_ostrides(1 + signal_ndim, 0);
  for (int64_t i = 1; i <= signal_ndim; i++) {
    TORCH_CHECK(in_strides[i] <= MKL_LONG_MAX && out_strides[i] <= MKL_LONG_MAX);
    mkl_istrides[i] = in_strides[i];
    mkl_ostrides[i] = out_strides[i];
  }
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_INPUT_STRIDES, mkl_istrides.data()));
  MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_OUTPUT_STRIDES, mkl_ostrides.data()));
  // if conjugate domain of real is involved, set standard CCE storage type
  // this will become default in MKL in future
  if (!complex_input || !complex_output) {
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX));
  }
  // rescale if requested
  const auto norm = static_cast<fft_norm_mode>(normalization);
  int64_t signal_numel = c10::multiply_integers(IntArrayRef(sizes.data() + 1, signal_ndim));
  if (norm != fft_norm_mode::none) {
    const double scale = (
      (norm == fft_norm_mode::by_root_n) ?
      1.0 / std::sqrt(static_cast<double>(signal_numel)) :
      1.0 / static_cast<double>(signal_numel));
    const auto scale_direction = forward ? DFTI_FORWARD_SCALE : DFTI_BACKWARD_SCALE;
    MKL_DFTI_CHECK(DftiSetValue(descriptor.get(), scale_direction, scale));
  }

  if (sizeof(MKL_LONG) < sizeof(int64_t)) {
    TORCH_CHECK(signal_numel <= MKL_LONG_MAX,
                "MKL FFT: input signal numel exceeds allowed range [1, ", MKL_LONG_MAX, "]");
  }

  // finalize
  MKL_DFTI_CHECK(DftiCommitDescriptor(descriptor.get()));

  return descriptor;
}

// Execute a general fft operation (can be c2c, onesided r2c or onesided c2r)
static Tensor& _exec_fft(Tensor& out, const Tensor& self, IntArrayRef out_sizes,
                         IntArrayRef dim, int64_t normalization, bool forward) {
  const auto ndim = self.dim();
  const int64_t signal_ndim = dim.size();
  const auto batch_dims = ndim - signal_ndim;

  // Permute dimensions so batch dimensions come first, and in stride order
  // This maximizes data locality when collapsing to a single batch dimension
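  // e.g. (illustrative) for a 4-d input with dim == {2, 3}, the batch dims
  // {0, 1} are moved to the front ordered by decreasing stride, followed by
  // the transformed dims in their given order.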
  DimVector dim_permute(ndim);
  std::iota(dim_permute.begin(), dim_permute.end(), int64_t{0});

  c10::SmallVector<bool, kDimVectorStaticSize> is_transformed_dim(ndim);
  for (const auto& d : dim) {
    is_transformed_dim[d] = true;
  }
  auto batch_end = std::partition(dim_permute.begin(), dim_permute.end(),
                                  [&](int64_t d) { return !is_transformed_dim[d]; });
  auto self_strides = self.strides();
  std::sort(dim_permute.begin(), batch_end,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  std::copy(dim.cbegin(), dim.cend(), batch_end);
  auto input = self.permute(dim_permute);

  // Collapse batch dimensions into a single dimension
  DimVector batched_sizes(signal_ndim + 1);
  batched_sizes[0] = -1;
  std::copy(input.sizes().cbegin() + batch_dims, input.sizes().cend(), batched_sizes.begin() + 1);
  input = input.reshape(batched_sizes);
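  // e.g. a permuted [2, 3, H, W] input becomes [6, H, W]; the leading -1 lets
  // reshape infer the collapsed batch size.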

  const auto batch_size = input.sizes()[0];
  DimVector signal_size(signal_ndim + 1);
  signal_size[0] = batch_size;
  for (const auto i : c10::irange(signal_ndim)) {
    auto in_size = input.sizes()[i + 1];
    auto out_size = out_sizes[dim[i]];
    signal_size[i + 1] = std::max(in_size, out_size);
    TORCH_INTERNAL_ASSERT(in_size == signal_size[i + 1] ||
                          in_size == (signal_size[i + 1] / 2) + 1);
    TORCH_INTERNAL_ASSERT(out_size == signal_size[i + 1] ||
                          out_size == (signal_size[i + 1] / 2) + 1);
  }
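  // signal_size holds the full two-sided length per transformed dim: for a
  // onesided r2c the output dim is halved (n/2 + 1), for a onesided c2r the
  // input dim is, so taking the max of the two recovers n.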

  batched_sizes[0] = batch_size;
  DimVector batched_out_sizes(batched_sizes.begin(), batched_sizes.end());
  for (const auto i : c10::irange(dim.size())) {
    batched_out_sizes[i + 1] = out_sizes[dim[i]];
  }

  const auto value_type = c10::toRealValueType(input.scalar_type());
  out.resize_(batched_out_sizes, MemoryFormat::Contiguous);

  auto descriptor = _plan_mkl_fft(
      input.strides(), out.strides(), signal_size, input.is_complex(),
      out.is_complex(), normalization, forward, value_type);

  // run the FFT
  if (forward) {
    MKL_DFTI_CHECK(DftiComputeForward(descriptor.get(), input.data_ptr(), out.data_ptr()));
  } else {
    MKL_DFTI_CHECK(DftiComputeBackward(descriptor.get(), input.data_ptr(), out.data_ptr()));
  }

  // Inplace reshaping to original batch shape and inverting the dimension permutation
  DimVector out_strides(ndim);
  int64_t batch_numel = 1;
  for (int64_t i = batch_dims - 1; i >= 0; --i) {
    out_strides[dim_permute[i]] = batch_numel * out.strides()[0];
    batch_numel *= out_sizes[dim_permute[i]];
  }
  for (const auto i : c10::irange(batch_dims, ndim)) {
    out_strides[dim_permute[i]] = out.strides()[1 + (i - batch_dims)];
  }
  out.as_strided_(out_sizes, out_strides, out.storage_offset());
  return out;
}

// Sort transform dimensions by input layout, for best performance
// exclude_last is for onesided transforms where the last dimension cannot be reordered
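// e.g. for a contiguous input, strides decrease with increasing dim index, so
// sorted_dims keeps the dims in ascending order; for other layouts the dim
// with the largest stride comes first.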
static DimVector _sort_dims(const Tensor& self, IntArrayRef dim, bool exclude_last=false) {
  DimVector sorted_dims(dim.begin(), dim.end());
  auto self_strides = self.strides();
  std::sort(sorted_dims.begin(), sorted_dims.end() - exclude_last,
            [&](int64_t a, int64_t b) { return self_strides[a] > self_strides[b]; });
  return sorted_dims;
}

// n-dimensional complex to real IFFT
Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  TORCH_CHECK(self.is_complex());
  // NOTE: Multi-dimensional C2R transforms don't agree with numpy in cases
  // where the input isn't strictly Hermitian-symmetric. Instead, we use a
  // multi-dim C2C transform followed by a 1D C2R transform.
  //
  // Such inputs are technically out of contract though, so maybe a disagreement
  // is okay.
  auto input = self;
  if (dim.size() > 1) {
    auto c2c_dims = dim.slice(0, dim.size() - 1);
    input = _fft_c2c_mkl(self, c2c_dims, normalization, /*forward=*/false);
    dim = dim.slice(dim.size() - 1);
  }
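  // e.g. for dim == {1, 2, 3} this performs a C2C inverse transform over
  // {1, 2} and then a 1-D C2R over {3}.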

  auto in_sizes = input.sizes();
  DimVector out_sizes(in_sizes.begin(), in_sizes.end());
  out_sizes[dim.back()] = last_dim_size;
  auto out = at::empty(out_sizes, self.options().dtype(c10::toRealValueType(self.scalar_type())));
  return _exec_fft(out, input, out_sizes, dim, normalization, /*forward=*/false);
}

// n-dimensional real to complex FFT
Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  TORCH_CHECK(self.is_floating_point());
  auto input_sizes = self.sizes();
  DimVector out_sizes(input_sizes.begin(), input_sizes.end());
  auto last_dim = dim.back();
  auto last_dim_halfsize = (input_sizes[last_dim]) / 2 + 1;
  if (onesided) {
    out_sizes[last_dim] = last_dim_halfsize;
  }

  auto sorted_dims = _sort_dims(self, dim, /*exclude_last=*/true);
  auto out = at::empty(out_sizes, self.options().dtype(c10::toComplexType(self.scalar_type())));
  _exec_fft(out, self, out_sizes, sorted_dims, normalization, /*forward=*/true);

  if (!onesided) {
    at::native::_fft_fill_with_conjugate_symmetry_(out, dim);
  }
  return out;
}

// n-dimensional complex to complex FFT/IFFT
Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  TORCH_CHECK(self.is_complex());
  const auto sorted_dims = _sort_dims(self, dim);
  auto out = at::empty(self.sizes(), self.options());
  return _exec_fft(out, self, self.sizes(), sorted_dims, normalization, forward);
}

}} // namespace at::native

#else

namespace at { namespace native {
REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub);

Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool onesided, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         int64_t last_dim_size, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization,
                         bool forward, Tensor& out) {
  AT_ERROR("fft: ATen not compiled with FFT support");
}

}} // namespace at::native
#endif