1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | #include "tensorflow/core/lib/core/bits.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | namespace { |
24 | |
25 | using shape_inference::DimensionHandle; |
26 | using shape_inference::InferenceContext; |
27 | using shape_inference::ShapeHandle; |
28 | |
29 | Status DecodeWavShapeFn(InferenceContext* c) { |
30 | ShapeHandle unused; |
31 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
32 | |
33 | DimensionHandle channels_dim; |
34 | int32_t desired_channels; |
35 | TF_RETURN_IF_ERROR(c->GetAttr("desired_channels" , &desired_channels)); |
36 | if (desired_channels == -1) { |
37 | channels_dim = c->UnknownDim(); |
38 | } else { |
39 | if (desired_channels < 0) { |
40 | return errors::InvalidArgument("channels must be non-negative, got " , |
41 | desired_channels); |
42 | } |
43 | channels_dim = c->MakeDim(desired_channels); |
44 | } |
45 | DimensionHandle samples_dim; |
46 | int32_t desired_samples; |
47 | TF_RETURN_IF_ERROR(c->GetAttr("desired_samples" , &desired_samples)); |
48 | if (desired_samples == -1) { |
49 | samples_dim = c->UnknownDim(); |
50 | } else { |
51 | if (desired_samples < 0) { |
52 | return errors::InvalidArgument("samples must be non-negative, got " , |
53 | desired_samples); |
54 | } |
55 | samples_dim = c->MakeDim(desired_samples); |
56 | } |
57 | c->set_output(0, c->MakeShape({samples_dim, channels_dim})); |
58 | c->set_output(1, c->Scalar()); |
59 | return OkStatus(); |
60 | } |
61 | |
62 | Status EncodeWavShapeFn(InferenceContext* c) { |
63 | ShapeHandle unused; |
64 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); |
65 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
66 | c->set_output(0, c->Scalar()); |
67 | return OkStatus(); |
68 | } |
69 | |
70 | Status SpectrogramShapeFn(InferenceContext* c) { |
71 | ShapeHandle input; |
72 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); |
73 | int32_t window_size; |
74 | TF_RETURN_IF_ERROR(c->GetAttr("window_size" , &window_size)); |
75 | int32_t stride; |
76 | TF_RETURN_IF_ERROR(c->GetAttr("stride" , &stride)); |
77 | |
78 | DimensionHandle input_length = c->Dim(input, 0); |
79 | DimensionHandle input_channels = c->Dim(input, 1); |
80 | |
81 | DimensionHandle output_length; |
82 | if (!c->ValueKnown(input_length)) { |
83 | output_length = c->UnknownDim(); |
84 | } else { |
85 | const int64_t input_length_value = c->Value(input_length); |
86 | const int64_t length_minus_window = (input_length_value - window_size); |
87 | int64_t output_length_value; |
88 | if (length_minus_window < 0) { |
89 | output_length_value = 0; |
90 | } else { |
91 | output_length_value = 1 + (length_minus_window / stride); |
92 | } |
93 | output_length = c->MakeDim(output_length_value); |
94 | } |
95 | |
96 | DimensionHandle output_channels = |
97 | c->MakeDim(1 + NextPowerOfTwo(window_size) / 2); |
98 | c->set_output(0, |
99 | c->MakeShape({input_channels, output_length, output_channels})); |
100 | return OkStatus(); |
101 | } |
102 | |
103 | Status MfccShapeFn(InferenceContext* c) { |
104 | ShapeHandle spectrogram; |
105 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram)); |
106 | ShapeHandle unused; |
107 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
108 | |
109 | int32_t dct_coefficient_count; |
110 | TF_RETURN_IF_ERROR( |
111 | c->GetAttr("dct_coefficient_count" , &dct_coefficient_count)); |
112 | |
113 | DimensionHandle spectrogram_channels = c->Dim(spectrogram, 0); |
114 | DimensionHandle spectrogram_length = c->Dim(spectrogram, 1); |
115 | |
116 | DimensionHandle output_channels = c->MakeDim(dct_coefficient_count); |
117 | |
118 | c->set_output(0, c->MakeShape({spectrogram_channels, spectrogram_length, |
119 | output_channels})); |
120 | return OkStatus(); |
121 | } |
122 | |
123 | } // namespace |
124 | |
// DecodeWav: string `contents` in, float `audio` and int32 `sample_rate` out.
// `desired_channels` and `desired_samples` default to -1, which
// DecodeWavShapeFn maps to unknown output dimensions.
// NOTE(review): the order of Input/Attr/Output calls presumably fixes the
// generated op signature; avoid reordering.
REGISTER_OP("DecodeWav")
    .Input("contents: string")
    .Attr("desired_channels: int = -1")
    .Attr("desired_samples: int = -1")
    .Output("audio: float")
    .Output("sample_rate: int32")
    .SetShapeFn(DecodeWavShapeFn);
132 | |
// EncodeWav: float `audio` plus int32 `sample_rate` in, string `contents`
// out. EncodeWavShapeFn requires `audio` to be rank 2 and `sample_rate` to
// be a scalar, and declares `contents` as a scalar.
REGISTER_OP("EncodeWav")
    .Input("audio: float")
    .Input("sample_rate: int32")
    .Output("contents: string")
    .SetShapeFn(EncodeWavShapeFn);
138 | |
// AudioSpectrogram: rank-2 float `input` in, float `spectrogram` out.
// SpectrogramShapeFn derives the output shape from `window_size` and
// `stride`; `magnitude_squared` is not read by the shape function.
REGISTER_OP("AudioSpectrogram")
    .Input("input: float")
    .Attr("window_size: int")
    .Attr("stride: int")
    .Attr("magnitude_squared: bool = false")
    .Output("spectrogram: float")
    .SetShapeFn(SpectrogramShapeFn);
146 | |
// Mfcc: rank-3 float `spectrogram` plus scalar int32 `sample_rate` in,
// float `output` out. Only `dct_coefficient_count` affects the output shape
// (it becomes the last dimension in MfccShapeFn). The frequency limits
// (presumably in Hz — confirm against the kernel) and
// `filterbank_channel_count` are not read by the shape function.
REGISTER_OP("Mfcc")
    .Input("spectrogram: float")
    .Input("sample_rate: int32")
    .Attr("upper_frequency_limit: float = 4000")
    .Attr("lower_frequency_limit: float = 20")
    .Attr("filterbank_channel_count: int = 40")
    .Attr("dct_coefficient_count: int = 13")
    .Output("output: float")
    .SetShapeFn(MfccShapeFn);
156 | |
157 | } // namespace tensorflow |
158 | |