1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/kernels/internal/mfcc.h"
16
17#include <stddef.h>
18#include <stdint.h>
19
20#include <vector>
21
22#include "flatbuffers/flexbuffers.h" // from @flatbuffers
23#include "tensorflow/lite/c/common.h"
24#include "tensorflow/lite/kernels/internal/compatibility.h"
25#include "tensorflow/lite/kernels/internal/mfcc_dct.h"
26#include "tensorflow/lite/kernels/internal/mfcc_mel_filterbank.h"
27#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
28#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
29#include "tensorflow/lite/kernels/internal/tensor.h"
30#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
31#include "tensorflow/lite/kernels/kernel_util.h"
32
33namespace tflite {
34namespace ops {
35namespace custom {
36namespace mfcc {
37
// Tag used as the template argument of Eval<>. kReference is the only kernel
// variant declared in this file.
enum KernelType { kReference };
41
// Per-node op parameters. Init() fills these from the op's custom options and
// Prepare()/Eval() read them back through node->user_data.
struct TfLiteMfccParams {
  // Frequency limits forwarded to internal::Mfcc by Eval().
  float upper_frequency_limit;
  float lower_frequency_limit;
  // Filterbank channel count forwarded to internal::Mfcc by Eval().
  int filterbank_channel_count;
  // Number of coefficients produced per frame; also used by Prepare() as the
  // innermost dimension of the output tensor.
  int dct_coefficient_count;
};
48
// Positions of the op's tensors within the node's input and output lists.
constexpr int kInputTensorWav = 0;   // Float spectrogram input.
constexpr int kInputTensorRate = 1;  // Single-element int32 sample rate.
constexpr int kOutputTensor = 0;     // Float output.
52
53void* Init(TfLiteContext* context, const char* buffer, size_t length) {
54 auto* data = new TfLiteMfccParams;
55
56 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
57
58 const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
59 data->upper_frequency_limit = m["upper_frequency_limit"].AsInt64();
60 data->lower_frequency_limit = m["lower_frequency_limit"].AsInt64();
61 data->filterbank_channel_count = m["filterbank_channel_count"].AsInt64();
62 data->dct_coefficient_count = m["dct_coefficient_count"].AsInt64();
63 return data;
64}
65
66void Free(TfLiteContext* context, void* buffer) {
67 delete reinterpret_cast<TfLiteMfccParams*>(buffer);
68}
69
70TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
71 auto* params = reinterpret_cast<TfLiteMfccParams*>(node->user_data);
72
73 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
74 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
75
76 const TfLiteTensor* input_wav;
77 TF_LITE_ENSURE_OK(context,
78 GetInputSafe(context, node, kInputTensorWav, &input_wav));
79 const TfLiteTensor* input_rate;
80 TF_LITE_ENSURE_OK(context,
81 GetInputSafe(context, node, kInputTensorRate, &input_rate));
82 TfLiteTensor* output;
83 TF_LITE_ENSURE_OK(context,
84 GetOutputSafe(context, node, kOutputTensor, &output));
85
86 TF_LITE_ENSURE_EQ(context, NumDimensions(input_wav), 3);
87 TF_LITE_ENSURE_EQ(context, NumElements(input_rate), 1);
88
89 TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
90 TF_LITE_ENSURE_TYPES_EQ(context, input_wav->type, output->type);
91 TF_LITE_ENSURE_TYPES_EQ(context, input_rate->type, kTfLiteInt32);
92
93 TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
94 output_size->data[0] = input_wav->dims->data[0];
95 output_size->data[1] = input_wav->dims->data[1];
96 output_size->data[2] = params->dct_coefficient_count;
97
98 return context->ResizeTensor(context, output, output_size);
99}
100
101// Input is a single squared-magnitude spectrogram frame. The input spectrum
102// is converted to linear magnitude and weighted into bands using a
103// triangular mel filterbank, and a discrete cosine transform (DCT) of the
104// values is taken. Output is populated with the lowest dct_coefficient_count
105// of these values.
106template <KernelType kernel_type>
107TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
108 auto* params = reinterpret_cast<TfLiteMfccParams*>(node->user_data);
109
110 const TfLiteTensor* input_wav;
111 TF_LITE_ENSURE_OK(context,
112 GetInputSafe(context, node, kInputTensorWav, &input_wav));
113 const TfLiteTensor* input_rate;
114 TF_LITE_ENSURE_OK(context,
115 GetInputSafe(context, node, kInputTensorRate, &input_rate));
116 TfLiteTensor* output;
117 TF_LITE_ENSURE_OK(context,
118 GetOutputSafe(context, node, kOutputTensor, &output));
119
120 const int32 sample_rate = *GetTensorData<int>(input_rate);
121
122 const int spectrogram_channels = input_wav->dims->data[2];
123 const int spectrogram_samples = input_wav->dims->data[1];
124 const int audio_channels = input_wav->dims->data[0];
125
126 internal::Mfcc mfcc;
127 mfcc.set_upper_frequency_limit(params->upper_frequency_limit);
128 mfcc.set_lower_frequency_limit(params->lower_frequency_limit);
129 mfcc.set_filterbank_channel_count(params->filterbank_channel_count);
130 mfcc.set_dct_coefficient_count(params->dct_coefficient_count);
131
132 mfcc.Initialize(spectrogram_channels, sample_rate);
133
134 const float* spectrogram_flat = GetTensorData<float>(input_wav);
135 float* output_flat = GetTensorData<float>(output);
136
137 for (int audio_channel = 0; audio_channel < audio_channels; ++audio_channel) {
138 for (int spectrogram_sample = 0; spectrogram_sample < spectrogram_samples;
139 ++spectrogram_sample) {
140 const float* sample_data =
141 spectrogram_flat +
142 (audio_channel * spectrogram_samples * spectrogram_channels) +
143 (spectrogram_sample * spectrogram_channels);
144 std::vector<double> mfcc_input(sample_data,
145 sample_data + spectrogram_channels);
146 std::vector<double> mfcc_output;
147 mfcc.Compute(mfcc_input, &mfcc_output);
148 TF_LITE_ENSURE_EQ(context, params->dct_coefficient_count,
149 mfcc_output.size());
150 float* output_data = output_flat +
151 (audio_channel * spectrogram_samples *
152 params->dct_coefficient_count) +
153 (spectrogram_sample * params->dct_coefficient_count);
154 for (int i = 0; i < params->dct_coefficient_count; ++i) {
155 output_data[i] = mfcc_output[i];
156 }
157 }
158 }
159
160 return kTfLiteOk;
161}
162
163} // namespace mfcc
164
165TfLiteRegistration* Register_MFCC() {
166 static TfLiteRegistration r = {mfcc::Init, mfcc::Free, mfcc::Prepare,
167 mfcc::Eval<mfcc::kReference>};
168 return &r;
169}
170
171} // namespace custom
172} // namespace ops
173} // namespace tflite
174