1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <math.h>
17#include <stddef.h>
18#include <stdint.h>
19
20#include <vector>
21
22#include "flatbuffers/flexbuffers.h" // from @flatbuffers
23#include "tensorflow/lite/c/common.h"
24#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
25#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
26#include "tensorflow/lite/kernels/internal/spectrogram.h"
27#include "tensorflow/lite/kernels/internal/tensor.h"
28#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
29#include "tensorflow/lite/kernels/kernel_util.h"
30
31namespace tflite {
32namespace ops {
33namespace custom {
34namespace audio_spectrogram {
35
// Index of the node's single input tensor (the audio samples).
constexpr int kInputTensor{0};
// Index of the node's single output tensor (the spectrogram).
constexpr int kOutputTensor{0};
38
// Tag used as the template argument of Eval<>; the reference kernel is the
// only variant defined here.
enum KernelType { kReference = 0 };
42
43typedef struct {
44 int window_size;
45 int stride;
46 bool magnitude_squared;
47 int output_height;
48 internal::Spectrogram* spectrogram;
49} TfLiteAudioSpectrogramParams;
50
51void* Init(TfLiteContext* context, const char* buffer, size_t length) {
52 auto* data = new TfLiteAudioSpectrogramParams;
53
54 const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
55
56 const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
57 data->window_size = m["window_size"].AsInt64();
58 data->stride = m["stride"].AsInt64();
59 data->magnitude_squared = m["magnitude_squared"].AsBool();
60
61 data->spectrogram = new internal::Spectrogram;
62
63 return data;
64}
65
66void Free(TfLiteContext* context, void* buffer) {
67 auto* params = reinterpret_cast<TfLiteAudioSpectrogramParams*>(buffer);
68 delete params->spectrogram;
69 delete params;
70}
71
72TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
73 auto* params =
74 reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data);
75
76 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
77 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
78
79 const TfLiteTensor* input;
80 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
81 TfLiteTensor* output;
82 TF_LITE_ENSURE_OK(context,
83 GetOutputSafe(context, node, kOutputTensor, &output));
84
85 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
86
87 TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
88 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
89
90 TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size,
91 params->stride));
92 const int64_t sample_count = input->dims->data[0];
93 const int64_t length_minus_window = (sample_count - params->window_size);
94 if (length_minus_window < 0) {
95 params->output_height = 0;
96 } else {
97 params->output_height = 1 + (length_minus_window / params->stride);
98 }
99 TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
100 output_size->data[0] = input->dims->data[1];
101 output_size->data[1] = params->output_height;
102 output_size->data[2] = params->spectrogram->output_frequency_channels();
103
104 return context->ResizeTensor(context, output, output_size);
105}
106
107template <KernelType kernel_type>
108TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
109 auto* params =
110 reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data);
111
112 const TfLiteTensor* input;
113 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
114 TfLiteTensor* output;
115 TF_LITE_ENSURE_OK(context,
116 GetOutputSafe(context, node, kOutputTensor, &output));
117
118 TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size,
119 params->stride));
120
121 const float* input_data = GetTensorData<float>(input);
122
123 const int64_t sample_count = input->dims->data[0];
124 const int64_t channel_count = input->dims->data[1];
125
126 const int64_t output_width = params->spectrogram->output_frequency_channels();
127
128 float* output_flat = GetTensorData<float>(output);
129
130 std::vector<float> input_for_channel(sample_count);
131 for (int64_t channel = 0; channel < channel_count; ++channel) {
132 float* output_slice =
133 output_flat + (channel * params->output_height * output_width);
134 for (int i = 0; i < sample_count; ++i) {
135 input_for_channel[i] = input_data[i * channel_count + channel];
136 }
137 std::vector<std::vector<float>> spectrogram_output;
138 TF_LITE_ENSURE(context,
139 params->spectrogram->ComputeSquaredMagnitudeSpectrogram(
140 input_for_channel, &spectrogram_output));
141 TF_LITE_ENSURE_EQ(context, spectrogram_output.size(),
142 params->output_height);
143 TF_LITE_ENSURE(context, spectrogram_output.empty() ||
144 (spectrogram_output[0].size() == output_width));
145 for (int row_index = 0; row_index < params->output_height; ++row_index) {
146 const std::vector<float>& spectrogram_row = spectrogram_output[row_index];
147 TF_LITE_ENSURE_EQ(context, spectrogram_row.size(), output_width);
148 float* output_row = output_slice + (row_index * output_width);
149 if (params->magnitude_squared) {
150 for (int i = 0; i < output_width; ++i) {
151 output_row[i] = spectrogram_row[i];
152 }
153 } else {
154 for (int i = 0; i < output_width; ++i) {
155 output_row[i] = sqrtf(spectrogram_row[i]);
156 }
157 }
158 }
159 }
160 return kTfLiteOk;
161}
162
163} // namespace audio_spectrogram
164
165TfLiteRegistration* Register_AUDIO_SPECTROGRAM() {
166 static TfLiteRegistration r = {
167 audio_spectrogram::Init, audio_spectrogram::Free,
168 audio_spectrogram::Prepare,
169 audio_spectrogram::Eval<audio_spectrogram::kReference>};
170 return &r;
171}
172
173} // namespace custom
174} // namespace ops
175} // namespace tflite
176