1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/spectrogram.h"
17
18#include <math.h>
19
20#include "third_party/fft2d/fft.h"
21#include "tensorflow/core/lib/core/bits.h"
22
23namespace tensorflow {
24
25using std::complex;
26
namespace {
// Fills `window` with a periodic Hann window of `window_length` samples:
//
//   w[i] = 0.5 - 0.5 * cos(2 * pi * i / window_length),  0 <= i < window_length
//
// "Periodic" here means the cosine period is `window_length` (not
// `window_length - 1`), so w[0] == 0 and the window is not symmetric about its
// last sample.
//
// A non-positive `window_length` produces an empty window. Resizing to a
// negative length would convert it to a huge unsigned size and fail
// (length_error / abort). With an empty window the caller's own validation
// ("Window length too short.") reports the problem instead.
void GetPeriodicHann(int window_length, std::vector<double>* window) {
  // Some platforms don't have M_PI, so define a local constant here.
  const double pi = std::atan(1) * 4;
  if (window_length <= 0) {
    window->clear();
    return;
  }
  window->resize(window_length);
  for (int i = 0; i < window_length; ++i) {
    (*window)[i] = 0.5 - 0.5 * cos((2 * pi * i) / window_length);
  }
}
}  // namespace
38
39bool Spectrogram::Initialize(int window_length, int step_length) {
40 std::vector<double> window;
41 GetPeriodicHann(window_length, &window);
42 return Initialize(window, step_length);
43}
44
45bool Spectrogram::Initialize(const std::vector<double>& window,
46 int step_length) {
47 window_length_ = window.size();
48 window_ = window; // Copy window.
49 if (window_length_ < 2) {
50 LOG(ERROR) << "Window length too short.";
51 initialized_ = false;
52 return false;
53 }
54
55 step_length_ = step_length;
56 if (step_length_ < 1) {
57 LOG(ERROR) << "Step length must be positive.";
58 initialized_ = false;
59 return false;
60 }
61
62 fft_length_ = NextPowerOfTwo(window_length_);
63 CHECK(fft_length_ >= window_length_);
64 output_frequency_channels_ = 1 + fft_length_ / 2;
65
66 // Allocate 2 more than what rdft needs, so we can rationalize the layout.
67 fft_input_output_.resize(fft_length_ + 2);
68
69 int half_fft_length = fft_length_ / 2;
70 fft_double_working_area_.resize(half_fft_length);
71 fft_integer_working_area_.resize(2 + static_cast<int>(sqrt(half_fft_length)));
72 initialized_ = true;
73 if (!Reset()) {
74 LOG(ERROR) << "Failed to Reset()";
75 return false;
76 }
77 return true;
78}
79
80bool Spectrogram::Reset() {
81 if (!initialized_) {
82 LOG(ERROR) << "Initialize() has to be called, before Reset().";
83 return false;
84 }
85 std::fill(fft_double_working_area_.begin(), fft_double_working_area_.end(),
86 0.0);
87 std::fill(fft_integer_working_area_.begin(), fft_integer_working_area_.end(),
88 0);
89
90 // Set flag element to ensure that the working areas are initialized
91 // on the first call to cdft. It's redundant given the assign above,
92 // but keep it as a reminder.
93 fft_integer_working_area_[0] = 0;
94 input_queue_.clear();
95 samples_to_next_step_ = window_length_;
96 return true;
97}
98
99template <class InputSample, class OutputSample>
100bool Spectrogram::ComputeComplexSpectrogram(
101 const std::vector<InputSample>& input,
102 std::vector<std::vector<complex<OutputSample>>>* output) {
103 if (!initialized_) {
104 LOG(ERROR) << "ComputeComplexSpectrogram() called before successful call "
105 << "to Initialize().";
106 return false;
107 }
108 CHECK(output);
109 output->clear();
110 int input_start = 0;
111 while (GetNextWindowOfSamples(input, &input_start)) {
112 DCHECK_EQ(input_queue_.size(), window_length_);
113 ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_.
114 // Add a new slice vector onto the output, to save new result to.
115 output->resize(output->size() + 1);
116 // Get a reference to the newly added slice to fill in.
117 auto& spectrogram_slice = output->back();
118 spectrogram_slice.resize(output_frequency_channels_);
119 for (int i = 0; i < output_frequency_channels_; ++i) {
120 // This will convert double to float if it needs to.
121 spectrogram_slice[i] = complex<OutputSample>(
122 fft_input_output_[2 * i], fft_input_output_[2 * i + 1]);
123 }
124 }
125 return true;
126}
127// Instantiate it four ways:
128template bool Spectrogram::ComputeComplexSpectrogram(
129 const std::vector<float>& input, std::vector<std::vector<complex<float>>>*);
130template bool Spectrogram::ComputeComplexSpectrogram(
131 const std::vector<double>& input,
132 std::vector<std::vector<complex<float>>>*);
133template bool Spectrogram::ComputeComplexSpectrogram(
134 const std::vector<float>& input,
135 std::vector<std::vector<complex<double>>>*);
136template bool Spectrogram::ComputeComplexSpectrogram(
137 const std::vector<double>& input,
138 std::vector<std::vector<complex<double>>>*);
139
140template <class InputSample, class OutputSample>
141bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
142 const std::vector<InputSample>& input,
143 std::vector<std::vector<OutputSample>>* output) {
144 if (!initialized_) {
145 LOG(ERROR) << "ComputeSquaredMagnitudeSpectrogram() called before "
146 << "successful call to Initialize().";
147 return false;
148 }
149 CHECK(output);
150 output->clear();
151 int input_start = 0;
152 while (GetNextWindowOfSamples(input, &input_start)) {
153 DCHECK_EQ(input_queue_.size(), window_length_);
154 ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_.
155 // Add a new slice vector onto the output, to save new result to.
156 output->resize(output->size() + 1);
157 // Get a reference to the newly added slice to fill in.
158 auto& spectrogram_slice = output->back();
159 spectrogram_slice.resize(output_frequency_channels_);
160 for (int i = 0; i < output_frequency_channels_; ++i) {
161 // Similar to the Complex case, except storing the norm.
162 // But the norm function is known to be a performance killer,
163 // so do it this way with explicit real and imaginary temps.
164 const double re = fft_input_output_[2 * i];
165 const double im = fft_input_output_[2 * i + 1];
166 // Which finally converts double to float if it needs to.
167 spectrogram_slice[i] = re * re + im * im;
168 }
169 }
170 return true;
171}
172// Instantiate it four ways:
173template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
174 const std::vector<float>& input, std::vector<std::vector<float>>*);
175template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
176 const std::vector<double>& input, std::vector<std::vector<float>>*);
177template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
178 const std::vector<float>& input, std::vector<std::vector<double>>*);
179template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
180 const std::vector<double>& input, std::vector<std::vector<double>>*);
181
182// Return true if a full window of samples is prepared; manage the queue.
183template <class InputSample>
184bool Spectrogram::GetNextWindowOfSamples(const std::vector<InputSample>& input,
185 int* input_start) {
186 auto input_it = input.begin() + *input_start;
187 int input_remaining = input.end() - input_it;
188 if (samples_to_next_step_ > input_remaining) {
189 // Copy in as many samples are left and return false, no full window.
190 input_queue_.insert(input_queue_.end(), input_it, input.end());
191 *input_start += input_remaining; // Increases it to input.size().
192 samples_to_next_step_ -= input_remaining;
193 return false; // Not enough for a full window.
194 } else {
195 // Copy just enough into queue to make a new window, then trim the
196 // front off the queue to make it window-sized.
197 input_queue_.insert(input_queue_.end(), input_it,
198 input_it + samples_to_next_step_);
199 *input_start += samples_to_next_step_;
200 input_queue_.erase(
201 input_queue_.begin(),
202 input_queue_.begin() + input_queue_.size() - window_length_);
203 DCHECK_EQ(window_length_, input_queue_.size());
204 samples_to_next_step_ = step_length_; // Be ready for next time.
205 return true; // Yes, input_queue_ now contains exactly a window-full.
206 }
207}
208
209void Spectrogram::ProcessCoreFFT() {
210 for (int j = 0; j < window_length_; ++j) {
211 fft_input_output_[j] = input_queue_[j] * window_[j];
212 }
213 // Zero-pad the rest of the input buffer.
214 for (int j = window_length_; j < fft_length_; ++j) {
215 fft_input_output_[j] = 0.0;
216 }
217 const int kForwardFFT = 1; // 1 means forward; -1 reverse.
218 // This real FFT is a fair amount faster than using cdft here.
219 rdft(fft_length_, kForwardFFT, &fft_input_output_[0],
220 &fft_integer_working_area_[0], &fft_double_working_area_[0]);
221 // Make rdft result look like cdft result;
222 // unpack the last real value from the first position's imag slot.
223 fft_input_output_[fft_length_] = fft_input_output_[1];
224 fft_input_output_[fft_length_ + 1] = 0;
225 fft_input_output_[1] = 0;
226}
227
228} // namespace tensorflow
229