/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/errors.h"
#define EIGEN_USE_THREADS

// See docs in ../ops/fft_ops.cc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/work_sharder.h"

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

class FFTBase : public OpKernel {
 public:
  explicit FFTBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in = ctx->input(0);
    const TensorShape& input_shape = in.shape();
    const int fft_rank = Rank();
    OP_REQUIRES(
        ctx, input_shape.dims() >= fft_rank,
        errors::InvalidArgument("Input must have rank of at least ", fft_rank,
                                " but got: ", input_shape.DebugString()));

    Tensor* out;
    TensorShape output_shape = input_shape;
    uint64 fft_shape[3] = {0, 0, 0};

    // In R2C or C2R mode, we use a second input to specify the FFT length
    // instead of inferring it from the input shape.
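    // For example, an RFFT with fft_length = [8] maps an input of shape
    // [..., n] (n >= 8) to a complex output of shape [..., 8 / 2 + 1] ==
    // [..., 5], while the corresponding IRFFT expects an inner-most input
    // dimension of at least 5 and produces a real output of shape [..., 8].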
    if (IsReal()) {
      const Tensor& fft_length = ctx->input(1);
      OP_REQUIRES(ctx,
                  fft_length.shape().dims() == 1 &&
                      fft_length.shape().dim_size(0) == fft_rank,
                  errors::InvalidArgument("fft_length must have shape [",
                                          fft_rank, "]"));

      auto fft_length_as_vec = fft_length.vec<int32>();
      for (int i = 0; i < fft_rank; ++i) {
        OP_REQUIRES(ctx, fft_length_as_vec(i) >= 0,
                    errors::InvalidArgument(
71 | "fft_length[" , i, |
72 | "] must >= 0, but got: " , fft_length_as_vec(i))); |
        fft_shape[i] = fft_length_as_vec(i);
        // Each input dimension must have length of at least fft_shape[i]. For
        // IRFFTs, the inner-most input dimension must have length of at least
        // fft_shape[i] / 2 + 1.
        bool inner_most = (i == fft_rank - 1);
        uint64 min_input_dim_length =
            !IsForward() && inner_most ? fft_shape[i] / 2 + 1 : fft_shape[i];
        auto input_index = input_shape.dims() - fft_rank + i;
        OP_REQUIRES(
            ctx,
            // We pass through empty tensors, so special case them here.
            input_shape.dim_size(input_index) == 0 ||
                input_shape.dim_size(input_index) >= min_input_dim_length,
            errors::InvalidArgument(
                "Input dimension ", input_index,
                " must have length of at least ", min_input_dim_length,
                " but got: ", input_shape.dim_size(input_index)));
        uint64 dim = IsForward() && inner_most && fft_shape[i] != 0
                         ? fft_shape[i] / 2 + 1
                         : fft_shape[i];
        output_shape.set_dim(output_shape.dims() - fft_rank + i, dim);
      }
    } else {
      for (int i = 0; i < fft_rank; ++i) {
        fft_shape[i] =
            output_shape.dim_size(output_shape.dims() - fft_rank + i);
      }
    }

    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &out));

    if (IsReal()) {
      if (IsForward()) {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_FLOAT && out->dtype() == DT_COMPLEX64) ||
                (in.dtype() == DT_DOUBLE && out->dtype() == DT_COMPLEX128),
            errors::InvalidArgument("Wrong types for forward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      } else {
        OP_REQUIRES(
            ctx,
            (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_FLOAT) ||
                (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_DOUBLE),
            errors::InvalidArgument("Wrong types for backward real FFT: in=",
                                    in.dtype(), " out=", out->dtype()));
      }
    } else {
      OP_REQUIRES(
          ctx,
          (in.dtype() == DT_COMPLEX64 && out->dtype() == DT_COMPLEX64) ||
              (in.dtype() == DT_COMPLEX128 && out->dtype() == DT_COMPLEX128),
          errors::InvalidArgument("Wrong types for FFT: in=", in.dtype(),
                                  " out=", out->dtype()));
    }

    if (input_shape.num_elements() == 0) {
      DCHECK_EQ(0, output_shape.num_elements());
      return;
    }

    DoFFT(ctx, in, fft_shape, out);
  }

 protected:
  virtual int Rank() const = 0;
  virtual bool IsForward() const = 0;
  virtual bool IsReal() const = 0;

  // The function that actually computes the FFT.
  virtual void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
                     Tensor* out) = 0;
};

typedef Eigen::ThreadPoolDevice CPUDevice;

template <bool Forward, bool _Real, int FFTRank>
class FFTCPU : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }

  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    // Create the axes (which are always trailing).
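    // flat_inner_dims below collapses all batch dimensions into dimension 0,
    // so the FFT axes are 1..FFTRank; e.g. for FFTRank == 2 the flattened
    // input has shape [batch, d0, d1] and axes == {1, 2}.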
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();

    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    if (!IsReal()) {
      // Compute the FFT using Eigen.
      constexpr auto direction =
          Forward ? Eigen::FFT_FORWARD : Eigen::FFT_REVERSE;
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        auto input = Tensor(in).flat_inner_dims<complex128, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex128, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        auto input = Tensor(in).flat_inner_dims<complex64, FFTRank + 1>();
        auto output = out->flat_inner_dims<complex64, FFTRank + 1>();
        output.device(device) =
            input.template fft<Eigen::BothParts, direction>(axes);
      }
    } else {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoRealForwardFFT<double, complex128>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoRealForwardFFT<float, complex64>(ctx, fft_shape, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoRealBackwardFFT<complex128, double>(ctx, fft_shape, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoRealBackwardFFT<complex64, float>(ctx, fft_shape, in, out);
        }
      }
    }
  }

  template <typename RealT, typename ComplexT>
  void DoRealForwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                        const Tensor& in, Tensor* out) {
    // Create the axes (which are always trailing).
    const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank);
    auto device = ctx->eigen_device<CPUDevice>();
    auto input = Tensor(in).flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Slice input to fft_shape on its inner-most dimensions.
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape temp_shape{input_dims[0]};
    for (int i = 1; i <= FFTRank; ++i) {
      input_slice_sizes[i] = fft_shape[i - 1];
      temp_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, temp_shape.num_elements() > 0,
                errors::InvalidArgument(
                    "Obtained an FFT shape of 0 elements: ",
                    temp_shape.DebugString()));

    auto output = out->flat_inner_dims<ComplexT, FFTRank + 1>();
    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> zero_start_indices;

    // Compute the full FFT using a temporary tensor.
    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           temp_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();
    full_fft.device(device) =
        input.slice(zero_start_indices, input_slice_sizes)
            .template fft<Eigen::BothParts, Eigen::FFT_FORWARD>(axes);

    // Slice away the negative frequency components.
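    // For real input of length N the spectrum is Hermitian-symmetric,
    // X[k] == conj(X[N - k]), so only the first N / 2 + 1 bins are unique
    // (e.g. bins 0..4 for N == 8); output.dimensions() already reflects the
    // N / 2 + 1 inner-most size computed in Compute().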
    output.device(device) =
        full_fft.slice(zero_start_indices, output.dimensions());
  }

  template <typename ComplexT, typename RealT>
  void DoRealBackwardFFT(OpKernelContext* ctx, uint64* fft_shape,
                         const Tensor& in, Tensor* out) {
    auto device = ctx->eigen_device<CPUDevice>();
    // Reconstruct the full FFT and take the inverse.
    auto input = Tensor(in).flat_inner_dims<ComplexT, FFTRank + 1>();
    auto output = out->flat_inner_dims<RealT, FFTRank + 1>();
    const auto input_dims = input.dimensions();

    // Calculate the shape of the temporary tensor for the full FFT and the
    // region we will slice from input given fft_shape. We slice input to
    // fft_shape on its inner-most dimensions, except the last (which we
    // slice to fft_shape[-1] / 2 + 1).
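    // For example, with FFTRank == 1 and fft_shape == {8}, the inner-most
    // slice size is 8 / 2 + 1 == 5 while full_fft_shape gets inner-most
    // dimension 8.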
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> input_slice_sizes;
    input_slice_sizes[0] = input_dims[0];
    TensorShape full_fft_shape;
    full_fft_shape.AddDim(input_dims[0]);
    for (auto i = 1; i <= FFTRank; i++) {
      input_slice_sizes[i] =
          i == FFTRank ? fft_shape[i - 1] / 2 + 1 : fft_shape[i - 1];
      full_fft_shape.AddDim(fft_shape[i - 1]);
    }
    OP_REQUIRES(ctx, full_fft_shape.num_elements() > 0,
                errors::InvalidArgument(
                    "Obtained an FFT shape of 0 elements: ",
                    full_fft_shape.DebugString()));

    Tensor temp;
    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<ComplexT>::v(),
                                           full_fft_shape, &temp));
    auto full_fft = temp.flat_inner_dims<ComplexT, FFTRank + 1>();

    // Calculate the starting point and range of the source of
    // negative frequency part.
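    // Continuing the N == 8 example: the input holds bins 0..4, and bins
    // 5..7 of the full spectrum are filled with conj(X[3]), conj(X[2]),
    // conj(X[1]), i.e. the reverse of bins 1..3, conjugated.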
    auto neg_sizes = input_slice_sizes;
    neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - input_slice_sizes[FFTRank];
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_target_indices;
    neg_target_indices[FFTRank] = input_slice_sizes[FFTRank];

    const Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> start_indices;
    Eigen::DSizes<Eigen::DenseIndex, FFTRank + 1> neg_start_indices;
    neg_start_indices[FFTRank] = 1;

    full_fft.slice(start_indices, input_slice_sizes).device(device) =
        input.slice(start_indices, input_slice_sizes);

    // First, conduct IFFTs on outer dimensions. We save computation (and
    // avoid touching uninitialized memory) by slicing full_fft to the
    // subregion we wrote input to.
    if (FFTRank > 1) {
      const auto outer_axes =
          Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1);
      full_fft.slice(start_indices, input_slice_sizes).device(device) =
          full_fft.slice(start_indices, input_slice_sizes)
              .template fft<Eigen::BothParts, Eigen::FFT_REVERSE>(outer_axes);
    }

    // Reconstruct the full FFT by appending reversed and conjugated
    // spectrum as the negative frequency part.
    Eigen::array<bool, FFTRank + 1> reverse_last_axis;
    for (auto i = 0; i <= FFTRank; i++) {
      reverse_last_axis[i] = i == FFTRank;
    }

    if (neg_sizes[FFTRank] != 0) {
      full_fft.slice(neg_target_indices, neg_sizes).device(device) =
          full_fft.slice(neg_start_indices, neg_sizes)
              .reverse(reverse_last_axis)
              .conjugate();
    }

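    // Finally, take the IFFT along the inner-most axis; Eigen::RealPart
    // keeps only the real component, which loses nothing here because the
    // full spectrum is Hermitian-symmetric by construction.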
    auto inner_axis = Eigen::array<int, 1>{FFTRank};
    output.device(device) =
        full_fft.template fft<Eigen::RealPart, Eigen::FFT_REVERSE>(inner_axis);
  }
};

REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_CPU), FFTCPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_CPU),
                        FFTCPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(Name("RFFT").Device(DEVICE_CPU), FFTCPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(Name("IRFFT").Device(DEVICE_CPU),
                        FFTCPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(Name("RFFT2D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(Name("IRFFT2D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(Name("RFFT3D").Device(DEVICE_CPU),
                        FFTCPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU),
                        FFTCPU<false, true, 3>);

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

namespace {
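// Wrap raw device pointers in typed StreamExecutor DeviceMemory handles. No
// ownership is taken; the caller must keep the underlying buffer alive.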
template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

template <typename T>
se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
  se::DeviceMemory<T> typed(wrapped);
  return typed;
}

// Provides a scratch-space allocator for the StreamExecutor cuFFT callback.
// TensorFlow is responsible for releasing the temporary buffers after the
// kernel finishes.
// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator
// into base class.
class CufftScratchAllocator : public se::ScratchAllocator {
 public:
  ~CufftScratchAllocator() override {}
  CufftScratchAllocator(int64_t memory_limit, OpKernelContext* context)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
  int64_t GetMemoryLimitInBytes() override { return memory_limit_; }
  se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
      int64_t byte_size) override {
    Tensor temporary_memory;
    if (byte_size > memory_limit_) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    AllocationAttributes allocation_attr;
    allocation_attr.retry_on_failure = false;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory,
        AllocatorAttributes(), allocation_attr));
    if (!allocation_status.ok()) {
      return se::port::StatusOr<se::DeviceMemory<uint8>>();
    }
    // Hold references to the allocated tensors until the allocator is
    // destroyed.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return se::port::StatusOr<se::DeviceMemory<uint8>>(
        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
                       temporary_memory.flat<uint8>().size()));
  }
  int64_t TotalByteSize() { return total_byte_size_; }

 private:
  int64_t memory_limit_;
  int64_t total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};

}  // end namespace

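// Returns the cuFFT scratch-space limit in bytes. The environment variable
// named by `envvar_in_mb` is interpreted in megabytes (e.g. setting
// TF_CUFFT_WORKSPACE_LIMIT_IN_MB=1024 caps scratch at 1 GiB); if it is unset
// or invalid, `default_value_in_bytes` is returned.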
int64_t GetCufftWorkspaceLimit(const string& envvar_in_mb,
                               int64_t default_value_in_bytes) {
  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
  if (workspace_limit_in_mb_str != nullptr &&
      strcmp(workspace_limit_in_mb_str, "") != 0) {
    int64_t scratch_limit_in_mb = -1;
    Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes,
                                        &scratch_limit_in_mb);
    if (!status.ok()) {
      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
                   << workspace_limit_in_mb_str;
    } else {
      return scratch_limit_in_mb * (1 << 20);
    }
  }
  return default_value_in_bytes;
}

class FFTGPUBase : public FFTBase {
 public:
  using FFTBase::FFTBase;

 protected:
  static int64_t CufftScratchSize;
  void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape,
             Tensor* out) override {
    auto* stream = ctx->op_device_context()->stream();
    OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();

    const int fft_rank = Rank();
    int batch_size = 1;
    for (int i = 0; i < input_shape.dims() - fft_rank; ++i) {
      batch_size *= input_shape.dim_size(i);
    }
    uint64 input_embed[3];
    const uint64 input_stride = 1;
    uint64 input_distance = 1;
    uint64 output_embed[3];
    const uint64 output_stride = 1;
    uint64 output_distance = 1;

    for (int i = 0; i < fft_rank; ++i) {
      auto dim_offset = input_shape.dims() - fft_rank + i;
      input_embed[i] = input_shape.dim_size(dim_offset);
      input_distance *= input_shape.dim_size(dim_offset);
      output_embed[i] = output_shape.dim_size(dim_offset);
      output_distance *= output_shape.dim_size(dim_offset);
    }
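    // For example, a rank-2 RFFT over a float input of shape [10, 8, 8]
    // (output shape [10, 8, 5]) yields batch_size == 10,
    // input_embed == {8, 8}, input_distance == 64, output_embed == {8, 5},
    // and output_distance == 40.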

    constexpr bool kInPlaceFft = false;
    const bool is_complex128 =
        in.dtype() == DT_COMPLEX128 || out->dtype() == DT_COMPLEX128;

    const auto kFftType =
        IsReal()
            ? (IsForward()
                   ? (is_complex128 ? se::fft::Type::kD2Z : se::fft::Type::kR2C)
                   : (is_complex128 ? se::fft::Type::kZ2D
                                    : se::fft::Type::kC2R))
            : (IsForward() ? (is_complex128 ? se::fft::Type::kZ2ZForward
                                            : se::fft::Type::kC2CForward)
                           : (is_complex128 ? se::fft::Type::kZ2ZInverse
                                            : se::fft::Type::kC2CInverse));

    CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
    auto plan =
        stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator(
            stream, fft_rank, fft_shape, input_embed, input_stride,
            input_distance, output_embed, output_stride, output_distance,
            kFftType, kInPlaceFft, batch_size, &scratch_allocator);
    OP_REQUIRES(
        ctx, plan != nullptr,
        errors::Internal(
            "Failed to create cuFFT batched plan with scratch allocator"));

    if (IsReal()) {
      if (IsForward()) {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_DOUBLE);
          DCHECK_EQ(out->dtype(), DT_COMPLEX128);
          DoFFTInternal<double, complex128>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_FLOAT);
          DCHECK_EQ(out->dtype(), DT_COMPLEX64);
          DoFFTInternal<float, complex64>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      } else {
        if (is_complex128) {
          DCHECK_EQ(in.dtype(), DT_COMPLEX128);
          DCHECK_EQ(out->dtype(), DT_DOUBLE);
          DoFFTInternal<complex128, double>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
        } else {
          DCHECK_EQ(in.dtype(), DT_COMPLEX64);
          DCHECK_EQ(out->dtype(), DT_FLOAT);
          DoFFTInternal<complex64, float>(ctx, stream, plan.get(), kFftType,
                                          output_distance, in, out);
        }
      }
    } else {
      if (is_complex128) {
        DCHECK_EQ(in.dtype(), DT_COMPLEX128);
        DCHECK_EQ(out->dtype(), DT_COMPLEX128);
        DoFFTInternal<complex128, complex128>(ctx, stream, plan.get(), kFftType,
                                              output_distance, in, out);
      } else {
        DCHECK_EQ(in.dtype(), DT_COMPLEX64);
        DCHECK_EQ(out->dtype(), DT_COMPLEX64);
        DoFFTInternal<complex64, complex64>(ctx, stream, plan.get(), kFftType,
                                            output_distance, in, out);
      }
    }
  }

 private:
  template <typename T>
  struct RealTypeFromComplexType {
    typedef T RealT;
  };

  template <typename T>
  struct RealTypeFromComplexType<std::complex<T>> {
    typedef T RealT;
  };

  template <typename InT, typename OutT>
  void DoFFTInternal(OpKernelContext* ctx, se::Stream* stream,
                     se::fft::Plan* plan, const se::fft::Type fft_type,
                     const uint64 output_distance, const Tensor& in,
                     Tensor* out) {
    const TensorShape& input_shape = in.shape();
    const TensorShape& output_shape = out->shape();
    auto src =
        AsDeviceMemory<InT>(in.flat<InT>().data(), input_shape.num_elements());
    auto dst = AsDeviceMemory<OutT>(out->flat<OutT>().data(),
                                    output_shape.num_elements());
    OP_REQUIRES(
        ctx, stream->ThenFft(plan, src, &dst).ok(),
        errors::Internal("fft failed: type=", static_cast<int>(fft_type),
                         " in.shape=", input_shape.DebugString()));
    if (!IsForward()) {
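      // cuFFT inverse transforms are unnormalized, so scale the result by
      // 1 / N, where N is the product of the FFT dimensions (equal to
      // output_distance here), to invert the forward transform exactly.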
      typedef typename RealTypeFromComplexType<OutT>::RealT RealT;
      RealT alpha = 1.0 / output_distance;
      OP_REQUIRES(
          ctx,
          stream->ThenBlasScal(output_shape.num_elements(), alpha, &dst, 1)
              .ok(),
          errors::Internal("BlasScal failed: in.shape=",
                           input_shape.DebugString()));
    }
  }
};

int64_t FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit(
    // default value is in bytes despite the name of the environment variable
    "TF_CUFFT_WORKSPACE_LIMIT_IN_MB", 1LL << 32  // 4GB
);

template <bool Forward, bool _Real, int FFTRank>
class FFTGPU : public FFTGPUBase {
 public:
  static_assert(FFTRank >= 1 && FFTRank <= 3,
                "Only 1D, 2D and 3D FFTs supported.");
  explicit FFTGPU(OpKernelConstruction* ctx) : FFTGPUBase(ctx) {}

 protected:
  int Rank() const override { return FFTRank; }
  bool IsForward() const override { return Forward; }
  bool IsReal() const override { return _Real; }
};

// Register GPU kernels with priority 1 so that if a custom FFT CPU kernel is
// registered with priority 1 (to override the default Eigen CPU kernel), the
// CPU kernel does not outrank the GPU kernel.
REGISTER_KERNEL_BUILDER(Name("FFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("IFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("FFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("IFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("FFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("IFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);

REGISTER_KERNEL_BUILDER(
    Name("RFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 1>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT2D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 2>);
REGISTER_KERNEL_BUILDER(
    Name("RFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<true, true, 3>);
REGISTER_KERNEL_BUILDER(
    Name("IRFFT3D").Device(DEVICE_GPU).HostMemory("fft_length").Priority(1),
    FFTGPU<false, true, 3>);

// Deprecated kernels.
REGISTER_KERNEL_BUILDER(Name("BatchFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 1>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT2D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 2>);
REGISTER_KERNEL_BUILDER(Name("BatchFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<true, false, 3>);
REGISTER_KERNEL_BUILDER(Name("BatchIFFT3D").Device(DEVICE_GPU).Priority(1),
                        FFTGPU<false, false, 3>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace tensorflow