1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/graph.pb.h"
17#include "tensorflow/core/framework/op_kernel.h"
18#include "tensorflow/core/framework/resource_mgr.h"
19#include "tensorflow/core/framework/summary.pb.h"
20#include "tensorflow/core/lib/core/refcount.h"
21#include "tensorflow/core/lib/db/sqlite.h"
22#include "tensorflow/core/platform/protobuf.h"
23#include "tensorflow/core/summary/schema.h"
24#include "tensorflow/core/summary/summary_db_writer.h"
25#include "tensorflow/core/summary/summary_file_writer.h"
26#include "tensorflow/core/util/event.pb.h"
27
28namespace tensorflow {
29
// Registers the CPU kernel for the "SummaryWriter" op. The kernel is the
// generic ResourceHandleOp instantiated for SummaryWriterInterface.
REGISTER_KERNEL_BUILDER(Name("SummaryWriter").Device(DEVICE_CPU),
                        ResourceHandleOp<SummaryWriterInterface>);
32
33class CreateSummaryFileWriterOp : public OpKernel {
34 public:
35 explicit CreateSummaryFileWriterOp(OpKernelConstruction* ctx)
36 : OpKernel(ctx) {}
37
38 void Compute(OpKernelContext* ctx) override {
39 const Tensor* tmp;
40 OP_REQUIRES_OK(ctx, ctx->input("logdir", &tmp));
41 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
42 errors::InvalidArgument("logdir must be a scalar"));
43 const string logdir = tmp->scalar<tstring>()();
44 OP_REQUIRES_OK(ctx, ctx->input("max_queue", &tmp));
45 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
46 errors::InvalidArgument("max_queue must be a scalar"));
47 const int32_t max_queue = tmp->scalar<int32>()();
48 OP_REQUIRES_OK(ctx, ctx->input("flush_millis", &tmp));
49 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
50 errors::InvalidArgument("flush_millis must be a scalar"));
51 const int32_t flush_millis = tmp->scalar<int32>()();
52 OP_REQUIRES_OK(ctx, ctx->input("filename_suffix", &tmp));
53 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
54 errors::InvalidArgument("filename_suffix must be a scalar"));
55 const string filename_suffix = tmp->scalar<tstring>()();
56
57 core::RefCountPtr<SummaryWriterInterface> s;
58 OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>(
59 ctx, HandleFromInput(ctx, 0), &s,
60 [max_queue, flush_millis, logdir, filename_suffix,
61 ctx](SummaryWriterInterface** s) {
62 return CreateSummaryFileWriter(
63 max_queue, flush_millis, logdir,
64 filename_suffix, ctx->env(), s);
65 }));
66 }
67};
68REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
69 CreateSummaryFileWriterOp);
70
71class CreateSummaryDbWriterOp : public OpKernel {
72 public:
73 explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
74
75 void Compute(OpKernelContext* ctx) override {
76 const Tensor* tmp;
77 OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp));
78 const string db_uri = tmp->scalar<tstring>()();
79 OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp));
80 const string experiment_name = tmp->scalar<tstring>()();
81 OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp));
82 const string run_name = tmp->scalar<tstring>()();
83 OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
84 const string user_name = tmp->scalar<tstring>()();
85
86 core::RefCountPtr<SummaryWriterInterface> s;
87 OP_REQUIRES_OK(
88 ctx,
89 LookupOrCreateResource<SummaryWriterInterface>(
90 ctx, HandleFromInput(ctx, 0), &s,
91 [db_uri, experiment_name, run_name, user_name,
92 ctx](SummaryWriterInterface** s) {
93 Sqlite* db;
94 TF_RETURN_IF_ERROR(Sqlite::Open(
95 db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db));
96 core::ScopedUnref unref(db);
97 TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db));
98 TF_RETURN_IF_ERROR(CreateSummaryDbWriter(
99 db, experiment_name, run_name, user_name, ctx->env(), s));
100 return OkStatus();
101 }));
102 }
103};
104REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
105 CreateSummaryDbWriterOp);
106
107class FlushSummaryWriterOp : public OpKernel {
108 public:
109 explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
110
111 void Compute(OpKernelContext* ctx) override {
112 core::RefCountPtr<SummaryWriterInterface> s;
113 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
114 OP_REQUIRES_OK(ctx, s->Flush());
115 }
116};
117REGISTER_KERNEL_BUILDER(Name("FlushSummaryWriter").Device(DEVICE_CPU),
118 FlushSummaryWriterOp);
119
120class CloseSummaryWriterOp : public OpKernel {
121 public:
122 explicit CloseSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
123
124 void Compute(OpKernelContext* ctx) override {
125 OP_REQUIRES_OK(ctx, DeleteResource<SummaryWriterInterface>(
126 ctx, HandleFromInput(ctx, 0)));
127 }
128};
129REGISTER_KERNEL_BUILDER(Name("CloseSummaryWriter").Device(DEVICE_CPU),
130 CloseSummaryWriterOp);
131
132class WriteSummaryOp : public OpKernel {
133 public:
134 explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
135
136 void Compute(OpKernelContext* ctx) override {
137 core::RefCountPtr<SummaryWriterInterface> s;
138 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
139 const Tensor* tmp;
140 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
141 const int64_t step = tmp->scalar<int64_t>()();
142 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
143 const string& tag = tmp->scalar<tstring>()();
144 OP_REQUIRES_OK(ctx, ctx->input("summary_metadata", &tmp));
145 const string& serialized_metadata = tmp->scalar<tstring>()();
146
147 const Tensor* t;
148 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
149
150 OP_REQUIRES_OK(ctx, s->WriteTensor(step, *t, tag, serialized_metadata));
151 }
152};
153REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU),
154 WriteSummaryOp);
155
156class WriteRawProtoSummaryOp : public OpKernel {
157 public:
158 explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
159
160 void Compute(OpKernelContext* ctx) override {
161 core::RefCountPtr<SummaryWriterInterface> s;
162 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
163 const Tensor* tmp;
164 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
165 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
166 errors::InvalidArgument("step must be scalar, got shape ",
167 tmp->shape().DebugString()));
168 const int64_t step = tmp->scalar<int64_t>()();
169 const Tensor* t;
170 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
171 std::unique_ptr<Event> event{new Event};
172 event->set_step(step);
173 event->set_wall_time(static_cast<double>(ctx->env()->NowMicros()) / 1.0e6);
174 // Each Summary proto contains just one repeated field "value" of Value
175 // messages with the actual data, so repeated Merge() is equivalent to
176 // concatenating all the Value entries together into a single Event.
177 const auto summary_pbs = t->flat<tstring>();
178 for (int i = 0; i < summary_pbs.size(); ++i) {
179 if (!event->mutable_summary()->MergeFromString(summary_pbs(i))) {
180 ctx->CtxFailureWithWarning(errors::DataLoss(
181 "Bad tf.compat.v1.Summary binary proto tensor string at index ",
182 i));
183 return;
184 }
185 }
186 OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event)));
187 }
188};
189REGISTER_KERNEL_BUILDER(Name("WriteRawProtoSummary").Device(DEVICE_CPU),
190 WriteRawProtoSummaryOp);
191
192class ImportEventOp : public OpKernel {
193 public:
194 explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
195
196 void Compute(OpKernelContext* ctx) override {
197 core::RefCountPtr<SummaryWriterInterface> s;
198 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
199 const Tensor* t;
200 OP_REQUIRES_OK(ctx, ctx->input("event", &t));
201 std::unique_ptr<Event> event{new Event};
202 if (!ParseProtoUnlimited(event.get(), t->scalar<tstring>()())) {
203 ctx->CtxFailureWithWarning(
204 errors::DataLoss("Bad tf.Event binary proto tensor string"));
205 return;
206 }
207 OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event)));
208 }
209};
210REGISTER_KERNEL_BUILDER(Name("ImportEvent").Device(DEVICE_CPU), ImportEventOp);
211
212class WriteScalarSummaryOp : public OpKernel {
213 public:
214 explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
215
216 void Compute(OpKernelContext* ctx) override {
217 core::RefCountPtr<SummaryWriterInterface> s;
218 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
219 const Tensor* tmp;
220 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
221 const int64_t step = tmp->scalar<int64_t>()();
222 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
223 const string& tag = tmp->scalar<tstring>()();
224
225 const Tensor* t;
226 OP_REQUIRES_OK(ctx, ctx->input("value", &t));
227
228 OP_REQUIRES_OK(ctx, s->WriteScalar(step, *t, tag));
229 }
230};
231REGISTER_KERNEL_BUILDER(Name("WriteScalarSummary").Device(DEVICE_CPU),
232 WriteScalarSummaryOp);
233
234class WriteHistogramSummaryOp : public OpKernel {
235 public:
236 explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
237
238 void Compute(OpKernelContext* ctx) override {
239 core::RefCountPtr<SummaryWriterInterface> s;
240 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
241 const Tensor* tmp;
242 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
243 const int64_t step = tmp->scalar<int64_t>()();
244 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
245 const string& tag = tmp->scalar<tstring>()();
246
247 const Tensor* t;
248 OP_REQUIRES_OK(ctx, ctx->input("values", &t));
249
250 OP_REQUIRES_OK(ctx, s->WriteHistogram(step, *t, tag));
251 }
252};
253REGISTER_KERNEL_BUILDER(Name("WriteHistogramSummary").Device(DEVICE_CPU),
254 WriteHistogramSummaryOp);
255
256class WriteImageSummaryOp : public OpKernel {
257 public:
258 explicit WriteImageSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
259 int64_t max_images_tmp;
260 OP_REQUIRES_OK(ctx, ctx->GetAttr("max_images", &max_images_tmp));
261 OP_REQUIRES(ctx, max_images_tmp < (1LL << 31),
262 errors::InvalidArgument("max_images must be < 2^31"));
263 max_images_ = static_cast<int32>(max_images_tmp);
264 }
265
266 void Compute(OpKernelContext* ctx) override {
267 core::RefCountPtr<SummaryWriterInterface> s;
268 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
269 const Tensor* tmp;
270 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
271 const int64_t step = tmp->scalar<int64_t>()();
272 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
273 const string& tag = tmp->scalar<tstring>()();
274 const Tensor* bad_color;
275 OP_REQUIRES_OK(ctx, ctx->input("bad_color", &bad_color));
276 OP_REQUIRES(
277 ctx, TensorShapeUtils::IsVector(bad_color->shape()),
278 errors::InvalidArgument("bad_color must be a vector, got shape ",
279 bad_color->shape().DebugString()));
280
281 const Tensor* t;
282 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
283
284 OP_REQUIRES_OK(ctx, s->WriteImage(step, *t, tag, max_images_, *bad_color));
285 }
286
287 private:
288 int32 max_images_;
289};
290REGISTER_KERNEL_BUILDER(Name("WriteImageSummary").Device(DEVICE_CPU),
291 WriteImageSummaryOp);
292
293class WriteAudioSummaryOp : public OpKernel {
294 public:
295 explicit WriteAudioSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
296 OP_REQUIRES_OK(ctx, ctx->GetAttr("max_outputs", &max_outputs_));
297 OP_REQUIRES(ctx, max_outputs_ > 0,
298 errors::InvalidArgument("max_outputs must be > 0"));
299 }
300
301 void Compute(OpKernelContext* ctx) override {
302 core::RefCountPtr<SummaryWriterInterface> s;
303 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
304 const Tensor* tmp;
305 OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
306 const int64_t step = tmp->scalar<int64_t>()();
307 OP_REQUIRES_OK(ctx, ctx->input("tag", &tmp));
308 const string& tag = tmp->scalar<tstring>()();
309 OP_REQUIRES_OK(ctx, ctx->input("sample_rate", &tmp));
310 const float sample_rate = tmp->scalar<float>()();
311
312 const Tensor* t;
313 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
314
315 OP_REQUIRES_OK(ctx,
316 s->WriteAudio(step, *t, tag, max_outputs_, sample_rate));
317 }
318
319 private:
320 int max_outputs_;
321};
322REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary").Device(DEVICE_CPU),
323 WriteAudioSummaryOp);
324
325class WriteGraphSummaryOp : public OpKernel {
326 public:
327 explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
328
329 void Compute(OpKernelContext* ctx) override {
330 core::RefCountPtr<SummaryWriterInterface> s;
331 OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
332 const Tensor* t;
333 OP_REQUIRES_OK(ctx, ctx->input("step", &t));
334 const int64_t step = t->scalar<int64_t>()();
335 OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
336 std::unique_ptr<GraphDef> graph{new GraphDef};
337 if (!ParseProtoUnlimited(graph.get(), t->scalar<tstring>()())) {
338 ctx->CtxFailureWithWarning(
339 errors::DataLoss("Bad tf.GraphDef binary proto tensor string"));
340 return;
341 }
342 OP_REQUIRES_OK(ctx, s->WriteGraph(step, std::move(graph)));
343 }
344};
345REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary").Device(DEVICE_CPU),
346 WriteGraphSummaryOp);
347
348} // namespace tensorflow
349