1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19#include "tensorflow/core/lib/core/bits.h"
20
21namespace tensorflow {
22
23namespace {
24
25using shape_inference::DimensionHandle;
26using shape_inference::InferenceContext;
27using shape_inference::ShapeHandle;
28
29Status DecodeWavShapeFn(InferenceContext* c) {
30 ShapeHandle unused;
31 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
32
33 DimensionHandle channels_dim;
34 int32_t desired_channels;
35 TF_RETURN_IF_ERROR(c->GetAttr("desired_channels", &desired_channels));
36 if (desired_channels == -1) {
37 channels_dim = c->UnknownDim();
38 } else {
39 if (desired_channels < 0) {
40 return errors::InvalidArgument("channels must be non-negative, got ",
41 desired_channels);
42 }
43 channels_dim = c->MakeDim(desired_channels);
44 }
45 DimensionHandle samples_dim;
46 int32_t desired_samples;
47 TF_RETURN_IF_ERROR(c->GetAttr("desired_samples", &desired_samples));
48 if (desired_samples == -1) {
49 samples_dim = c->UnknownDim();
50 } else {
51 if (desired_samples < 0) {
52 return errors::InvalidArgument("samples must be non-negative, got ",
53 desired_samples);
54 }
55 samples_dim = c->MakeDim(desired_samples);
56 }
57 c->set_output(0, c->MakeShape({samples_dim, channels_dim}));
58 c->set_output(1, c->Scalar());
59 return OkStatus();
60}
61
62Status EncodeWavShapeFn(InferenceContext* c) {
63 ShapeHandle unused;
64 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
65 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
66 c->set_output(0, c->Scalar());
67 return OkStatus();
68}
69
70Status SpectrogramShapeFn(InferenceContext* c) {
71 ShapeHandle input;
72 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
73 int32_t window_size;
74 TF_RETURN_IF_ERROR(c->GetAttr("window_size", &window_size));
75 int32_t stride;
76 TF_RETURN_IF_ERROR(c->GetAttr("stride", &stride));
77
78 DimensionHandle input_length = c->Dim(input, 0);
79 DimensionHandle input_channels = c->Dim(input, 1);
80
81 DimensionHandle output_length;
82 if (!c->ValueKnown(input_length)) {
83 output_length = c->UnknownDim();
84 } else {
85 const int64_t input_length_value = c->Value(input_length);
86 const int64_t length_minus_window = (input_length_value - window_size);
87 int64_t output_length_value;
88 if (length_minus_window < 0) {
89 output_length_value = 0;
90 } else {
91 output_length_value = 1 + (length_minus_window / stride);
92 }
93 output_length = c->MakeDim(output_length_value);
94 }
95
96 DimensionHandle output_channels =
97 c->MakeDim(1 + NextPowerOfTwo(window_size) / 2);
98 c->set_output(0,
99 c->MakeShape({input_channels, output_length, output_channels}));
100 return OkStatus();
101}
102
103Status MfccShapeFn(InferenceContext* c) {
104 ShapeHandle spectrogram;
105 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram));
106 ShapeHandle unused;
107 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
108
109 int32_t dct_coefficient_count;
110 TF_RETURN_IF_ERROR(
111 c->GetAttr("dct_coefficient_count", &dct_coefficient_count));
112
113 DimensionHandle spectrogram_channels = c->Dim(spectrogram, 0);
114 DimensionHandle spectrogram_length = c->Dim(spectrogram, 1);
115
116 DimensionHandle output_channels = c->MakeDim(dct_coefficient_count);
117
118 c->set_output(0, c->MakeShape({spectrogram_channels, spectrogram_length,
119 output_channels}));
120 return OkStatus();
121}
122
123} // namespace
124
// Registers the "DecodeWav" op.
//   Input:   contents (string) - DecodeWavShapeFn requires it to be a scalar.
//   Attrs:   desired_channels, desired_samples (int, default -1). The shape
//            function maps -1 to an unknown dimension and rejects other
//            negative values.
//   Outputs: audio (float) with shape [samples, channels], and
//            sample_rate (int32) as a scalar.
125REGISTER_OP("DecodeWav")
126 .Input("contents: string")
127 .Attr("desired_channels: int = -1")
128 .Attr("desired_samples: int = -1")
129 .Output("audio: float")
130 .Output("sample_rate: int32")
131 .SetShapeFn(DecodeWavShapeFn);
132
// Registers the "EncodeWav" op.
//   Inputs: audio (float) - EncodeWavShapeFn requires rank 2;
//           sample_rate (int32) - EncodeWavShapeFn requires a scalar.
//   Output: contents (string), declared scalar by the shape function.
133REGISTER_OP("EncodeWav")
134 .Input("audio: float")
135 .Input("sample_rate: int32")
136 .Output("contents: string")
137 .SetShapeFn(EncodeWavShapeFn);
138
// Registers the "AudioSpectrogram" op.
//   Input:  input (float) - SpectrogramShapeFn requires rank 2.
//   Attrs:  window_size, stride (int, no default, so callers must supply
//           them); magnitude_squared (bool, default false). Only window_size
//           and stride are read by the shape function.
//   Output: spectrogram (float), rank 3 per SpectrogramShapeFn.
139REGISTER_OP("AudioSpectrogram")
140 .Input("input: float")
141 .Attr("window_size: int")
142 .Attr("stride: int")
143 .Attr("magnitude_squared: bool = false")
144 .Output("spectrogram: float")
145 .SetShapeFn(SpectrogramShapeFn);
146
// Registers the "Mfcc" op.
//   Inputs: spectrogram (float) - MfccShapeFn requires rank 3;
//           sample_rate (int32) - MfccShapeFn requires a scalar.
//   Attrs:  upper_frequency_limit (float, default 4000),
//           lower_frequency_limit (float, default 20),
//           filterbank_channel_count (int, default 40),
//           dct_coefficient_count (int, default 13).
//           Only dct_coefficient_count is read by the shape function, where
//           it becomes the last output dimension; the others are presumably
//           consumed by the kernel (not visible in this file).
//   Output: output (float), rank 3 per MfccShapeFn.
147REGISTER_OP("Mfcc")
148 .Input("spectrogram: float")
149 .Input("sample_rate: int32")
150 .Attr("upper_frequency_limit: float = 4000")
151 .Attr("lower_frequency_limit: float = 20")
152 .Attr("filterbank_channel_count: int = 40")
153 .Attr("dct_coefficient_count: int = 13")
154 .Output("output: float")
155 .SetShapeFn(MfccShapeFn);
156
157} // namespace tensorflow
158