1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <math.h> |
17 | #include <stddef.h> |
18 | #include <stdint.h> |
19 | |
20 | #include <vector> |
21 | |
22 | #include "flatbuffers/flexbuffers.h" // from @flatbuffers |
23 | #include "tensorflow/lite/c/common.h" |
24 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
25 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
26 | #include "tensorflow/lite/kernels/internal/spectrogram.h" |
27 | #include "tensorflow/lite/kernels/internal/tensor.h" |
28 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
29 | #include "tensorflow/lite/kernels/kernel_util.h" |
30 | |
31 | namespace tflite { |
32 | namespace ops { |
33 | namespace custom { |
34 | namespace audio_spectrogram { |
35 | |
// Index of the op's only input tensor within the node's input list.
constexpr int kInputTensor{0};
// Index of the op's only output tensor within the node's output list.
constexpr int kOutputTensor{0};
38 | |
// Selects which implementation Eval<> is instantiated for. Only a single
// variant exists in this file.
enum KernelType {
  kReference = 0,
};
42 | |
43 | typedef struct { |
44 | int window_size; |
45 | int stride; |
46 | bool magnitude_squared; |
47 | int output_height; |
48 | internal::Spectrogram* spectrogram; |
49 | } TfLiteAudioSpectrogramParams; |
50 | |
51 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
52 | auto* data = new TfLiteAudioSpectrogramParams; |
53 | |
54 | const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer); |
55 | |
56 | const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); |
57 | data->window_size = m["window_size" ].AsInt64(); |
58 | data->stride = m["stride" ].AsInt64(); |
59 | data->magnitude_squared = m["magnitude_squared" ].AsBool(); |
60 | |
61 | data->spectrogram = new internal::Spectrogram; |
62 | |
63 | return data; |
64 | } |
65 | |
66 | void Free(TfLiteContext* context, void* buffer) { |
67 | auto* params = reinterpret_cast<TfLiteAudioSpectrogramParams*>(buffer); |
68 | delete params->spectrogram; |
69 | delete params; |
70 | } |
71 | |
72 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
73 | auto* params = |
74 | reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data); |
75 | |
76 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
77 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
78 | |
79 | const TfLiteTensor* input; |
80 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
81 | TfLiteTensor* output; |
82 | TF_LITE_ENSURE_OK(context, |
83 | GetOutputSafe(context, node, kOutputTensor, &output)); |
84 | |
85 | TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2); |
86 | |
87 | TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32); |
88 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); |
89 | |
90 | TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size, |
91 | params->stride)); |
92 | const int64_t sample_count = input->dims->data[0]; |
93 | const int64_t length_minus_window = (sample_count - params->window_size); |
94 | if (length_minus_window < 0) { |
95 | params->output_height = 0; |
96 | } else { |
97 | params->output_height = 1 + (length_minus_window / params->stride); |
98 | } |
99 | TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); |
100 | output_size->data[0] = input->dims->data[1]; |
101 | output_size->data[1] = params->output_height; |
102 | output_size->data[2] = params->spectrogram->output_frequency_channels(); |
103 | |
104 | return context->ResizeTensor(context, output, output_size); |
105 | } |
106 | |
107 | template <KernelType kernel_type> |
108 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
109 | auto* params = |
110 | reinterpret_cast<TfLiteAudioSpectrogramParams*>(node->user_data); |
111 | |
112 | const TfLiteTensor* input; |
113 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
114 | TfLiteTensor* output; |
115 | TF_LITE_ENSURE_OK(context, |
116 | GetOutputSafe(context, node, kOutputTensor, &output)); |
117 | |
118 | TF_LITE_ENSURE(context, params->spectrogram->Initialize(params->window_size, |
119 | params->stride)); |
120 | |
121 | const float* input_data = GetTensorData<float>(input); |
122 | |
123 | const int64_t sample_count = input->dims->data[0]; |
124 | const int64_t channel_count = input->dims->data[1]; |
125 | |
126 | const int64_t output_width = params->spectrogram->output_frequency_channels(); |
127 | |
128 | float* output_flat = GetTensorData<float>(output); |
129 | |
130 | std::vector<float> input_for_channel(sample_count); |
131 | for (int64_t channel = 0; channel < channel_count; ++channel) { |
132 | float* output_slice = |
133 | output_flat + (channel * params->output_height * output_width); |
134 | for (int i = 0; i < sample_count; ++i) { |
135 | input_for_channel[i] = input_data[i * channel_count + channel]; |
136 | } |
137 | std::vector<std::vector<float>> spectrogram_output; |
138 | TF_LITE_ENSURE(context, |
139 | params->spectrogram->ComputeSquaredMagnitudeSpectrogram( |
140 | input_for_channel, &spectrogram_output)); |
141 | TF_LITE_ENSURE_EQ(context, spectrogram_output.size(), |
142 | params->output_height); |
143 | TF_LITE_ENSURE(context, spectrogram_output.empty() || |
144 | (spectrogram_output[0].size() == output_width)); |
145 | for (int row_index = 0; row_index < params->output_height; ++row_index) { |
146 | const std::vector<float>& spectrogram_row = spectrogram_output[row_index]; |
147 | TF_LITE_ENSURE_EQ(context, spectrogram_row.size(), output_width); |
148 | float* output_row = output_slice + (row_index * output_width); |
149 | if (params->magnitude_squared) { |
150 | for (int i = 0; i < output_width; ++i) { |
151 | output_row[i] = spectrogram_row[i]; |
152 | } |
153 | } else { |
154 | for (int i = 0; i < output_width; ++i) { |
155 | output_row[i] = sqrtf(spectrogram_row[i]); |
156 | } |
157 | } |
158 | } |
159 | } |
160 | return kTfLiteOk; |
161 | } |
162 | |
163 | } // namespace audio_spectrogram |
164 | |
165 | TfLiteRegistration* Register_AUDIO_SPECTROGRAM() { |
166 | static TfLiteRegistration r = { |
167 | audio_spectrogram::Init, audio_spectrogram::Free, |
168 | audio_spectrogram::Prepare, |
169 | audio_spectrogram::Eval<audio_spectrogram::kReference>}; |
170 | return &r; |
171 | } |
172 | |
173 | } // namespace custom |
174 | } // namespace ops |
175 | } // namespace tflite |
176 | |