1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as |
17 | // inputs or outputs in various ways. |
18 | |
19 | // See docs in ../ops/summary_ops.cc. |
20 | |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/summary.pb.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/lib/wav/wav_io.h" |
25 | #include "tensorflow/core/platform/types.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | class SummaryAudioOp : public OpKernel { |
30 | public: |
31 | explicit SummaryAudioOp(OpKernelConstruction* context) : OpKernel(context) { |
32 | OP_REQUIRES_OK(context, context->GetAttr("max_outputs" , &max_outputs_)); |
33 | OP_REQUIRES(context, max_outputs_ > 0, |
34 | errors::InvalidArgument("max_outputs must be > 0" )); |
35 | has_sample_rate_attr_ = |
36 | context->GetAttr("sample_rate" , &sample_rate_attr_).ok(); |
37 | } |
38 | |
39 | void Compute(OpKernelContext* c) override { |
40 | const Tensor& tag = c->input(0); |
41 | const Tensor& tensor = c->input(1); |
42 | OP_REQUIRES(c, TensorShapeUtils::IsScalar(tag.shape()), |
43 | errors::InvalidArgument("Tag must be a scalar" )); |
44 | OP_REQUIRES(c, tensor.dims() >= 2 && tensor.dims() <= 3, |
45 | errors::InvalidArgument("Tensor must be 3-D or 2-D, got: " , |
46 | tensor.shape().DebugString())); |
47 | const string& base_tag = tag.scalar<tstring>()(); |
48 | |
49 | float sample_rate = sample_rate_attr_; |
50 | if (!has_sample_rate_attr_) { |
51 | const Tensor& sample_rate_tensor = c->input(2); |
52 | OP_REQUIRES(c, |
53 | sample_rate_tensor.IsAligned() && |
54 | sample_rate_tensor.NumElements() == 1, |
55 | errors::InvalidArgument( |
56 | "sample_rate must be rank-0 or contain a single value" )); |
57 | sample_rate = sample_rate_tensor.scalar<float>()(); |
58 | } |
59 | OP_REQUIRES(c, sample_rate > 0.0f, |
60 | errors::InvalidArgument("sample_rate must be > 0" )); |
61 | |
62 | const int batch_size = tensor.dim_size(0); |
63 | const int64_t length_frames = tensor.dim_size(1); |
64 | const int64_t num_channels = |
65 | tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1); |
66 | |
67 | Summary s; |
68 | const int N = std::min<int>(max_outputs_, batch_size); |
69 | for (int i = 0; i < N; ++i) { |
70 | Summary::Value* v = s.add_value(); |
71 | if (max_outputs_ > 1) { |
72 | v->set_tag(strings::StrCat(base_tag, "/audio/" , i)); |
73 | } else { |
74 | v->set_tag(strings::StrCat(base_tag, "/audio" )); |
75 | } |
76 | |
77 | Summary::Audio* sa = v->mutable_audio(); |
78 | sa->set_sample_rate(sample_rate); |
79 | sa->set_num_channels(num_channels); |
80 | sa->set_length_frames(length_frames); |
81 | sa->set_content_type("audio/wav" ); |
82 | |
83 | auto values = |
84 | tensor.shaped<float, 3>({batch_size, length_frames, num_channels}); |
85 | const float* data = |
86 | tensor.NumElements() == 0 ? nullptr : &values(i, 0, 0); |
87 | |
88 | size_t sample_rate_truncated = lrintf(sample_rate); |
89 | if (sample_rate_truncated == 0) { |
90 | sample_rate_truncated = 1; |
91 | } |
92 | OP_REQUIRES_OK(c, wav::EncodeAudioAsS16LEWav( |
93 | data, sample_rate_truncated, num_channels, |
94 | length_frames, sa->mutable_encoded_audio_string())); |
95 | } |
96 | |
97 | Tensor* summary_tensor = nullptr; |
98 | OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor)); |
99 | CHECK(SerializeToTString(s, &summary_tensor->scalar<tstring>()())); |
100 | } |
101 | |
102 | private: |
103 | int max_outputs_; |
104 | bool has_sample_rate_attr_; |
105 | float sample_rate_attr_; |
106 | }; |
107 | |
108 | REGISTER_KERNEL_BUILDER(Name("AudioSummaryV2" ).Device(DEVICE_CPU), |
109 | SummaryAudioOp); |
110 | |
111 | // Deprecated -- this op is registered with sample_rate as an attribute for |
112 | // backwards compatibility. |
113 | REGISTER_KERNEL_BUILDER(Name("AudioSummary" ).Device(DEVICE_CPU), |
114 | SummaryAudioOp); |
115 | |
116 | } // namespace tensorflow |
117 | |