/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

// See docs in ../ops/fft_ops.cc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/work_sharder.h"

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace tensorflow {

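// Base class for all FFT kernels. Compute() validates the input/output shapes
// and dtypes, resolves the transform lengths (from the `fft_length` input for
// real transforms, from the trailing input dimensions otherwise), allocates
// the output, and then delegates the actual transform to DoFFT().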
class FFTBase : public OpKernel {
 public:
  explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in = ctx->input(0);
    const TensorShape& input_shape = in.shape();
    const int fft_rank = Rank();
    OP_REQUIRES(
        ctx, input_shape.dims() >= fft_rank,
        errors::InvalidArgument("Input must have rank of at least ", fft_rank,
                                " but got: ", input_shape.DebugString()));

    Tensor* out;
    TensorShape output_shape = input_shape;
    uint64 fft_shape[3] = {0, 0, 0};

    // In R2C or C2R mode, we use a second input to specify the FFT length
    // instead of inferring it from the input shape.
    if (IsReal()) {
      const Tensor& fft_length = ctx->input(1);
      OP_REQUIRES(ctx,
                  fft_length.shape().dims() == 1 &&
                      fft_length.shape().dim_size(0) == fft_rank,
                  errors::InvalidArgument("fft_length must have shape [",
                                          fft_rank, "]"));

      auto fft_length_as_vec = fft_length.vec<int32>();
      for (int i = 0; i < fft_rank; ++i) {
        OP_REQUIRES(ctx, fft_length_as_vec(i) >= 0,
                    errors::InvalidArgument(
                        "fft_length[", i,
                        "] must be >= 0, but got: ", fft_length_as_vec(i)));
        fft_shape[i] = fft_length_as_vec(i);
        // Each input dimension must have length of at least fft_shape[i]. For
        // IRFFTs, the inner-most input dimension must have length of at least
        // fft_shape[i] / 2 + 1.
        bool inner_most = (i == fft_rank - 1);
        uint64 min_input_dim_length =
            !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
        auto input_index = input_shape.dims() - fft_rank + i;
        OP_REQUIRES(
            ctx,
            // We pass through empty tensors, so special case them here.
            input_shape.dim_size(input_index) == 0 ||
                input_shape.dim_size(input_index) >= min_input_dim_length,
            errors::InvalidArgument(
                "Input dimension ", input_index,
                " must have length of at least ", min_input_dim_length,
                " but got: ", input_shape.dim_size(input_index)));
        uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
                         ? fft_shape[i] / 2 + 1
                         : fft_shape[i];
        output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
      }
    } else {
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] =
            output_shape.dim_size(output_shape.dims() - fft_rank + i);
      }
    }

    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));

    if (IsReal()) {
      if (IsForward()) {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64) ||
                (in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128),
            errors::InvalidArgument("Wrong types for forward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      } else {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT) ||
                (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE),
            errors::InvalidArgument("Wrong types for backward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      }
    } else {
      OP_REQUIRES(
          ctx,
          (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) ||
              (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128),
          errors::InvalidArgument("Wrong types for FFT: in=", in.dtype(),
                                  " out=", out->dtype()));
    }

    if (input_shape.num_elements() == 0) {
      DCHECK_EQ(0, output_shape.num_elements());
      return;
    }

    DoFFT(ctx, in, fft_shape, out);
  }

 protected:
  virtual int Rank() const = 0;
  virtual bool IsForward() const = 0;
  virtual bool IsReal() const = 0;

  // The function that actually computes the FFT.
  virtual void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
                     Tensor* out) = 0;
};

typedef Eigen::ThreadPoolDevice CPUDevice;

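// Eigen-based CPU implementation. The template parameters select the kernel
// variant: `Forward` chooses FFT vs. IFFT, `_Real` chooses the real-input
// (RFFT/IRFFT) vs. complex path, and `FFTRank` is the transform rank (1-3).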
template <bool Forward, bool _Real, int FFTRank>
class FFTCPU : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }

  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();

    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    if (!IsReal()) {
      // Compute the FFT using Eigen.
      constexpr auto direction =
          Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      }
    } else {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoRealForwardFFT<double, complex128>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoRealForwardFFT<float, complex64>(ctx, fft_shape, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoRealBackwardFFT<complex128, double>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoRealBackwardFFT<complex64, float>(ctx, fft_shape, in, out);
        }
      }
    }
  }

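  // R2C path: crops the input to `fft_shape` (shape validation in Compute()
  // guarantees the input is at least that large), computes a full
  // complex-to-complex FFT into a temporary tensor, and then copies only the
  // non-negative frequency components into the output.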
  template <typename RealT, typename ComplexT>
  void DoRealForwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                        const Tensor& in, Tensor* out) {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();
    auto input = Tensor(in).flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Slice input to fft_shape on its inner-most dimensions.
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape temp_shape{input_dims[0]};
    for (int i = 1; i <= FFTRank; ++i) {
      input_slice_sizes[i] = fft_shape[i - 1];
      temp_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, temp_shape.num_elements() > 0,
                errors::InvalidArgument(
                    "Obtained an FFT shape of 0 elements: ",
                    temp_shape.DebugString()));

    auto output = out->flat_inner_dims<ComplexT, FFTRank + 1>();
    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;

    // Compute the full FFT using a temporary tensor.
    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           temp_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
    full_fft.device(device) =
        input.slice(zero_start_indices, input_slice_sizes)
            .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);

    // Slice away the negative frequency components.
    output.device(device) =
        full_fft.slice(zero_start_indices, output.dimensions());
  }

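  // C2R path: copies the non-negative frequency input into a full-size
  // temporary, runs inverse FFTs over the outer axes, reconstructs the
  // negative frequencies of the inner-most axis from the reversed and
  // conjugated spectrum (Hermitian symmetry), and finally takes the real part
  // of the inverse FFT along that axis.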
  template <typename ComplexT, typename RealT>
  void DoRealBackwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                         const Tensor& in, Tensor* out) {
    auto device = ctx->eigen_device<CPUDevice>();
    // Reconstruct the full FFT and take the inverse.
    auto input = Tensor(in).flat_inner_dims<ComplexT, FFTRank + 1>();
    auto output = out->flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Calculate the shape of the temporary tensor for the full FFT and the
    // region we will slice from input given fft_shape. We slice input to
    // fft_shape on its inner-most dimensions, except the last (which we
    // slice to fft_shape[-1] / 2 + 1).
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape full_fft_shape;
    full_fft_shape.AddDim(input_dims[0]);
    for (auto i = 1; i <= FFTRank; i++) {
      input_slice_sizes[i] =
          i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
      full_fft_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, full_fft_shape.num_elements() > 0,
                errors::InvalidArgument(
                    "Obtained an FFT shape of 0 elements: ",
                    full_fft_shape.DebugString()));

    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           full_fft_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();

    // Calculate the starting point and range of the source of
    // negative frequency part.
    auto neg_sizes = input_slice_sizes;
    neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
    neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];

    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
    neg_start_indices[FFTRank] = 1;

    full_fft.slice(start_indices, input_slice_sizes).device(device) =
        input.slice(start_indices, input_slice_sizes);

    // First, conduct IFFTs on outer dimensions. We save computation (and
    // avoid touching uninitialized memory) by slicing full_fft to the
    // subregion we wrote input to.
    if (FFTRank > 1) {
      const auto outer_axes =
          Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
      full_fft.slice(start_indices, input_slice_sizes).device(device) =
          full_fft.slice(start_indices, input_slice_sizes)
              .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
    }

    // Reconstruct the full FFT by appending reversed and conjugated
    // spectrum as the negative frequency part.
    Eigen::array<bool, FFTRank + 1> reverse_last_axis;
    for (auto i = 0; i <= FFTRank; i++) {
      reverse_last_axis[i] = i == FFTRank;
    }

    if (neg_sizes[FFTRank] != 0) {
      full_fft.slice(neg_target_indices, neg_sizes).device(device) =
          full_fft.slice(neg_start_indices, neg_sizes)
              .reverse(reverse_last_axis)
              .conjugate();
    }

    auto inner_axis = Eigen::array<int, 1>{FFTRank};
    output.device(device) =
        full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
  }
};

REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU), FFTCPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_CPU),
                        FFTCPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU), FFTCPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU),
                        FFTCPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 3>);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

namespace {
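// Wraps a raw device pointer (optionally with an element count) in a
// non-owning StreamExecutor DeviceMemory handle.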
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

// A scratch-space allocator passed to the StreamExecutor cuFFT plan API.
// TensorFlow is responsible for releasing the temporary buffers after the
// kernel finishes.
// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator
// into base class.
class CufftScratchAllocator : public se::ScratchAllocator {
 public:
  ~CufftScratchAllocator() override {}
  CufftScratchAllocator(int64_t memory_limit, OpKernelContext* context)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
  int64_t GetMemoryLimitInBytes() override { return memory_limit_; }
  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
      int64_t byte_size) override {
    Tensor temporary_memory;
    if (byte_size > memory_limit_) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    AllocationAttributes allocation_attr;
    allocation_attr.retry_on_failure = false;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory,
        AllocatorAttributes(), allocation_attr));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    // Hold the reference of the allocated tensors until the end of the
    // allocator.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return se::port::StatusOr<se::DeviceMemory<uint8>>(
        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
                       temporary_memory.flat<uint8>().size()));
  }
  int64_t TotalByteSize() { return total_byte_size_; }

 private:
  int64_t memory_limit_;
  int64_t total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};

}  // end namespace

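// Returns the workspace limit (in bytes) for cuFFT scratch allocations. The
// environment variable named by `envvar_in_mb` is interpreted in megabytes;
// if it is unset or cannot be parsed, `default_value_in_bytes` is returned.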
int64_t GetCufftWorkspaceLimit(const string& envvar_in_mb,
                               int64_t default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64_t scratch_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
                                        &scratch_limit_in_mb);
    if (!status.ok()) {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    } else {
      return scratch_limit_in_mb * (1 << 20);
    }
  }
  return default_value_in_bytes;
}

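// GPU implementation of FFTBase. Builds a batched FFT plan through the
// StreamExecutor FFT interface (backed by cuFFT on CUDA builds) and enqueues
// the transform on the op's GPU stream.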
class FFTGPUBase : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  static int64_t CufftScratchSize;
  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();

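    // All leading dimensions are folded into a single batch count; the
    // trailing `fft_rank` dimensions describe the layout of each individual
    // transform in the batched plan.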
    const int fft_rank = Rank();
    int batch_size = 1;
    for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
      batch_size *= input_shape.dim_size(i);
    }
    uint64 input_embed[3];
    const uint64 input_stride = 1;
    uint64 input_distance = 1;
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;

    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape.dims() - fft_rank + i;
      input_embed[i] = input_shape.dim_size(dim_offset);
      input_distance *= input_shape.dim_size(dim_offset);
      output_embed[i] = output_shape.dim_size(dim_offset);
      output_distance *= output_shape.dim_size(dim_offset);
    }

    constexpr bool kInPlaceFft = false;
    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    const auto kFftType =
        IsReal()
            ? (IsForward()
                   ? (is_complex128 ? se::fft::Type::kD2Z : se::fft::Type::kR2C)
                   : (is_complex128 ? se::fft::Type::kZ2D
                                    : se::fft::Type::kC2R))
            : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
                                            : se::fft::Type::kC2CForward)
                           : (is_complex128 ? se::fft::Type::kZ2ZInverse
                                            : se::fft::Type::kC2CInverse));

    CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
    auto plan =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_shape, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            kFftType, kInPlaceFft, batch_size, &scratch_allocator);
    OP_REQUIRES(
        ctx, plan != nullptr,
        errors::Internal(
            "Failed to create cuFFT batched plan with scratch allocator"));

    if (IsReal()) {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoFFTInternal<double, complex128>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoFFTInternal<float, complex64>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoFFTInternal<complex128, double>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoFFTInternal<complex64, float>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      }
    } else {
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        DoFFTInternal<complex128, complex128>(ctx, stream, plan.get(), kFftType,
                                              output_distance, in, out);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        DoFFTInternal<complex64, complex64>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
      }
    }
  }

 private:
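  // Extracts the underlying real type from std::complex<T>; the primary
  // template is the identity for types that are already real.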
  template <typename T>
  struct RealTypeFromComplexType {
    typedef T RealT;
  };

  template <typename T>
  struct RealTypeFromComplexType<std::complex<T>> {
    typedef T RealT;
  };

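  // Enqueues the planned transform on `stream`. The underlying FFT library is
  // unnormalized, so inverse transforms are additionally scaled by
  // 1 / output_distance (the number of elements in each transform) with a
  // BLAS scal call on the same stream.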
  template <typename InT, typename OutT>
  void DoFFTInternal(OpKernelContext* ctx, se::Stream* stream,
                     se::fft::Plan* plan, const se::fft::Type fft_type,
                     const uint64 output_distance, const Tensor& in,
                     Tensor* out) {
    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();
    auto src =
        AsDeviceMemory<InT>(in.flat<InT>().data(), input_shape.num_elements());
    auto dst = AsDeviceMemory<OutT>(out->flat<OutT>().data(),
                                    output_shape.num_elements());
    OP_REQUIRES(
        ctx, stream->ThenFft(plan, src, &dst).ok(),
        errors::Internal("fft failed : type=", static_cast<int>(fft_type),
                         " in.shape=", input_shape.DebugString()));
    if (!IsForward()) {
      typedef typename RealTypeFromComplexType<OutT>::RealT RealT;
      RealT alpha = 1.0 / output_distance;
      OP_REQUIRES(
          ctx,
          stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
              .ok(),
          errors::Internal("BlasScal failed : in.shape=",
                           input_shape.DebugString()));
    }
  }
};

int64_t FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
    // default value is in bytes despite the name of the environment variable
    "TF_CUFFT_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
);

template <bool Forward, bool _Real, int FFTRank>
class FFTGPU : public FFTGPUBase {
 public:
  static_assert(FFTRank >= 1 && FFTRank <= 3,
                "Only 1D, 2D and 3D FFTs supported.");
  explicit FFTGPU(OpKernelConstruction* ctx) : FFTGPUBase(ctx) {}

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }
};

// Register GPU kernels with priority 1 so that if a custom FFT CPU kernel is
// registered with priority 1 (to override the default Eigen CPU kernel), the
// CPU kernel does not outrank the GPU kernel.
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);

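// fft_length is read on the host in FFTBase::Compute() to determine the
// transform shape, so it is pinned to host memory for the GPU kernels.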
REGISTER_KERNEL_BUILDER(
    Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 3>);

// Deprecated kernels.
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow