1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
17// inputs or outputs in various ways.
18
19// See docs in ../ops/summary_ops.cc.
20
21#include "tensorflow/core/framework/op_kernel.h"
22#include "tensorflow/core/framework/summary.pb.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/lib/wav/wav_io.h"
25#include "tensorflow/core/platform/types.h"
26
27namespace tensorflow {
28
29class SummaryAudioOp : public OpKernel {
30 public:
31 explicit SummaryAudioOp(OpKernelConstruction* context) : OpKernel(context) {
32 OP_REQUIRES_OK(context, context->GetAttr("max_outputs", &max_outputs_));
33 OP_REQUIRES(context, max_outputs_ > 0,
34 errors::InvalidArgument("max_outputs must be > 0"));
35 has_sample_rate_attr_ =
36 context->GetAttr("sample_rate", &sample_rate_attr_).ok();
37 }
38
39 void Compute(OpKernelContext* c) override {
40 const Tensor& tag = c->input(0);
41 const Tensor& tensor = c->input(1);
42 OP_REQUIRES(c, TensorShapeUtils::IsScalar(tag.shape()),
43 errors::InvalidArgument("Tag must be a scalar"));
44 OP_REQUIRES(c, tensor.dims() >= 2 && tensor.dims() <= 3,
45 errors::InvalidArgument("Tensor must be 3-D or 2-D, got: ",
46 tensor.shape().DebugString()));
47 const string& base_tag = tag.scalar<tstring>()();
48
49 float sample_rate = sample_rate_attr_;
50 if (!has_sample_rate_attr_) {
51 const Tensor& sample_rate_tensor = c->input(2);
52 OP_REQUIRES(c,
53 sample_rate_tensor.IsAligned() &&
54 sample_rate_tensor.NumElements() == 1,
55 errors::InvalidArgument(
56 "sample_rate must be rank-0 or contain a single value"));
57 sample_rate = sample_rate_tensor.scalar<float>()();
58 }
59 OP_REQUIRES(c, sample_rate > 0.0f,
60 errors::InvalidArgument("sample_rate must be > 0"));
61
62 const int batch_size = tensor.dim_size(0);
63 const int64_t length_frames = tensor.dim_size(1);
64 const int64_t num_channels =
65 tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1);
66
67 Summary s;
68 const int N = std::min<int>(max_outputs_, batch_size);
69 for (int i = 0; i < N; ++i) {
70 Summary::Value* v = s.add_value();
71 if (max_outputs_ > 1) {
72 v->set_tag(strings::StrCat(base_tag, "/audio/", i));
73 } else {
74 v->set_tag(strings::StrCat(base_tag, "/audio"));
75 }
76
77 Summary::Audio* sa = v->mutable_audio();
78 sa->set_sample_rate(sample_rate);
79 sa->set_num_channels(num_channels);
80 sa->set_length_frames(length_frames);
81 sa->set_content_type("audio/wav");
82
83 auto values =
84 tensor.shaped<float, 3>({batch_size, length_frames, num_channels});
85 const float* data =
86 tensor.NumElements() == 0 ? nullptr : &values(i, 0, 0);
87
88 size_t sample_rate_truncated = lrintf(sample_rate);
89 if (sample_rate_truncated == 0) {
90 sample_rate_truncated = 1;
91 }
92 OP_REQUIRES_OK(c, wav::EncodeAudioAsS16LEWav(
93 data, sample_rate_truncated, num_channels,
94 length_frames, sa->mutable_encoded_audio_string()));
95 }
96
97 Tensor* summary_tensor = nullptr;
98 OP_REQUIRES_OK(c, c->allocate_output(0, TensorShape({}), &summary_tensor));
99 CHECK(SerializeToTString(s, &summary_tensor->scalar<tstring>()()));
100 }
101
102 private:
103 int max_outputs_;
104 bool has_sample_rate_attr_;
105 float sample_rate_attr_;
106};
107
// Current op. Per the note on the deprecated op below, only "AudioSummary"
// carries a `sample_rate` attribute. For this op, SummaryAudioOp therefore
// reads the rate from input 2 at Compute() time.
REGISTER_KERNEL_BUILDER(Name("AudioSummaryV2").Device(DEVICE_CPU),
                        SummaryAudioOp);

// Deprecated -- this op is registered with sample_rate as an attribute for
// backwards compatibility. The same kernel class serves both ops. It detects
// whether the attribute exists when it is constructed.
REGISTER_KERNEL_BUILDER(Name("AudioSummary").Device(DEVICE_CPU),
                        SummaryAudioOp);
115
116} // namespace tensorflow
117