1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/audio_ops.cc |
17 | |
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/mfcc.h"
#include "tensorflow/core/lib/core/status.h"
25 | |
26 | namespace tensorflow { |
27 | |
28 | // Create a speech fingerpring from spectrogram data. |
29 | class MfccOp : public OpKernel { |
30 | public: |
31 | explicit MfccOp(OpKernelConstruction* context) : OpKernel(context) { |
32 | OP_REQUIRES_OK(context, context->GetAttr("upper_frequency_limit" , |
33 | &upper_frequency_limit_)); |
34 | OP_REQUIRES_OK(context, context->GetAttr("lower_frequency_limit" , |
35 | &lower_frequency_limit_)); |
36 | OP_REQUIRES_OK(context, context->GetAttr("filterbank_channel_count" , |
37 | &filterbank_channel_count_)); |
38 | OP_REQUIRES_OK(context, context->GetAttr("dct_coefficient_count" , |
39 | &dct_coefficient_count_)); |
40 | } |
41 | |
42 | void Compute(OpKernelContext* context) override { |
43 | const Tensor& spectrogram = context->input(0); |
44 | OP_REQUIRES(context, spectrogram.dims() == 3, |
45 | errors::InvalidArgument("spectrogram must be 3-dimensional" , |
46 | spectrogram.shape().DebugString())); |
47 | const Tensor& sample_rate_tensor = context->input(1); |
48 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()), |
49 | errors::InvalidArgument( |
50 | "Input sample_rate should be a scalar tensor, got " , |
51 | sample_rate_tensor.shape().DebugString(), " instead." )); |
52 | const int32_t sample_rate = sample_rate_tensor.scalar<int32>()(); |
53 | |
54 | const int spectrogram_channels = spectrogram.dim_size(2); |
55 | const int spectrogram_samples = spectrogram.dim_size(1); |
56 | const int audio_channels = spectrogram.dim_size(0); |
57 | |
58 | Mfcc mfcc; |
59 | mfcc.set_upper_frequency_limit(upper_frequency_limit_); |
60 | mfcc.set_lower_frequency_limit(lower_frequency_limit_); |
61 | mfcc.set_filterbank_channel_count(filterbank_channel_count_); |
62 | mfcc.set_dct_coefficient_count(dct_coefficient_count_); |
63 | OP_REQUIRES(context, mfcc.Initialize(spectrogram_channels, sample_rate), |
64 | errors::InvalidArgument( |
65 | "Mfcc initialization failed for channel count " , |
66 | spectrogram_channels, " and sample rate " , sample_rate)); |
67 | |
68 | Tensor* output_tensor = nullptr; |
69 | OP_REQUIRES_OK(context, |
70 | context->allocate_output( |
71 | 0, |
72 | TensorShape({audio_channels, spectrogram_samples, |
73 | dct_coefficient_count_}), |
74 | &output_tensor)); |
75 | |
76 | const float* spectrogram_flat = spectrogram.flat<float>().data(); |
77 | float* output_flat = output_tensor->flat<float>().data(); |
78 | |
79 | for (int audio_channel = 0; audio_channel < audio_channels; |
80 | ++audio_channel) { |
81 | for (int spectrogram_sample = 0; spectrogram_sample < spectrogram_samples; |
82 | ++spectrogram_sample) { |
83 | const float* sample_data = |
84 | spectrogram_flat + |
85 | (audio_channel * spectrogram_samples * spectrogram_channels) + |
86 | (spectrogram_sample * spectrogram_channels); |
87 | std::vector<double> mfcc_input(sample_data, |
88 | sample_data + spectrogram_channels); |
89 | std::vector<double> mfcc_output; |
90 | mfcc.Compute(mfcc_input, &mfcc_output); |
91 | DCHECK_EQ(dct_coefficient_count_, mfcc_output.size()); |
92 | float* output_data = |
93 | output_flat + |
94 | (audio_channel * spectrogram_samples * dct_coefficient_count_) + |
95 | (spectrogram_sample * dct_coefficient_count_); |
96 | for (int i = 0; i < dct_coefficient_count_; ++i) { |
97 | output_data[i] = mfcc_output[i]; |
98 | } |
99 | } |
100 | } |
101 | } |
102 | |
103 | private: |
104 | float upper_frequency_limit_; |
105 | float lower_frequency_limit_; |
106 | int32 filterbank_channel_count_; |
107 | int32 dct_coefficient_count_; |
108 | }; |
109 | REGISTER_KERNEL_BUILDER(Name("Mfcc" ).Device(DEVICE_CPU), MfccOp); |
110 | |
111 | } // namespace tensorflow |
112 | |