1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/graph.pb.h" |
17 | #include "tensorflow/core/framework/op_kernel.h" |
18 | #include "tensorflow/core/framework/resource_mgr.h" |
19 | #include "tensorflow/core/framework/summary.pb.h" |
20 | #include "tensorflow/core/lib/core/refcount.h" |
21 | #include "tensorflow/core/lib/db/sqlite.h" |
22 | #include "tensorflow/core/platform/protobuf.h" |
23 | #include "tensorflow/core/summary/schema.h" |
24 | #include "tensorflow/core/summary/summary_db_writer.h" |
25 | #include "tensorflow/core/summary/summary_file_writer.h" |
26 | #include "tensorflow/core/util/event.pb.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | REGISTER_KERNEL_BUILDER(Name("SummaryWriter" ).Device(DEVICE_CPU), |
31 | ResourceHandleOp<SummaryWriterInterface>); |
32 | |
33 | class CreateSummaryFileWriterOp : public OpKernel { |
34 | public: |
35 | explicit CreateSummaryFileWriterOp(OpKernelConstruction* ctx) |
36 | : OpKernel(ctx) {} |
37 | |
38 | void Compute(OpKernelContext* ctx) override { |
39 | const Tensor* tmp; |
40 | OP_REQUIRES_OK(ctx, ctx->input("logdir" , &tmp)); |
41 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), |
42 | errors::InvalidArgument("logdir must be a scalar" )); |
43 | const string logdir = tmp->scalar<tstring>()(); |
44 | OP_REQUIRES_OK(ctx, ctx->input("max_queue" , &tmp)); |
45 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), |
46 | errors::InvalidArgument("max_queue must be a scalar" )); |
47 | const int32_t max_queue = tmp->scalar<int32>()(); |
48 | OP_REQUIRES_OK(ctx, ctx->input("flush_millis" , &tmp)); |
49 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), |
50 | errors::InvalidArgument("flush_millis must be a scalar" )); |
51 | const int32_t flush_millis = tmp->scalar<int32>()(); |
52 | OP_REQUIRES_OK(ctx, ctx->input("filename_suffix" , &tmp)); |
53 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), |
54 | errors::InvalidArgument("filename_suffix must be a scalar" )); |
55 | const string filename_suffix = tmp->scalar<tstring>()(); |
56 | |
57 | core::RefCountPtr<SummaryWriterInterface> s; |
58 | OP_REQUIRES_OK(ctx, LookupOrCreateResource<SummaryWriterInterface>( |
59 | ctx, HandleFromInput(ctx, 0), &s, |
60 | [max_queue, flush_millis, logdir, filename_suffix, |
61 | ctx](SummaryWriterInterface** s) { |
62 | return CreateSummaryFileWriter( |
63 | max_queue, flush_millis, logdir, |
64 | filename_suffix, ctx->env(), s); |
65 | })); |
66 | } |
67 | }; |
68 | REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter" ).Device(DEVICE_CPU), |
69 | CreateSummaryFileWriterOp); |
70 | |
71 | class CreateSummaryDbWriterOp : public OpKernel { |
72 | public: |
73 | explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
74 | |
75 | void Compute(OpKernelContext* ctx) override { |
76 | const Tensor* tmp; |
77 | OP_REQUIRES_OK(ctx, ctx->input("db_uri" , &tmp)); |
78 | const string db_uri = tmp->scalar<tstring>()(); |
79 | OP_REQUIRES_OK(ctx, ctx->input("experiment_name" , &tmp)); |
80 | const string experiment_name = tmp->scalar<tstring>()(); |
81 | OP_REQUIRES_OK(ctx, ctx->input("run_name" , &tmp)); |
82 | const string run_name = tmp->scalar<tstring>()(); |
83 | OP_REQUIRES_OK(ctx, ctx->input("user_name" , &tmp)); |
84 | const string user_name = tmp->scalar<tstring>()(); |
85 | |
86 | core::RefCountPtr<SummaryWriterInterface> s; |
87 | OP_REQUIRES_OK( |
88 | ctx, |
89 | LookupOrCreateResource<SummaryWriterInterface>( |
90 | ctx, HandleFromInput(ctx, 0), &s, |
91 | [db_uri, experiment_name, run_name, user_name, |
92 | ctx](SummaryWriterInterface** s) { |
93 | Sqlite* db; |
94 | TF_RETURN_IF_ERROR(Sqlite::Open( |
95 | db_uri, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE, &db)); |
96 | core::ScopedUnref unref(db); |
97 | TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db)); |
98 | TF_RETURN_IF_ERROR(CreateSummaryDbWriter( |
99 | db, experiment_name, run_name, user_name, ctx->env(), s)); |
100 | return OkStatus(); |
101 | })); |
102 | } |
103 | }; |
104 | REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter" ).Device(DEVICE_CPU), |
105 | CreateSummaryDbWriterOp); |
106 | |
107 | class FlushSummaryWriterOp : public OpKernel { |
108 | public: |
109 | explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
110 | |
111 | void Compute(OpKernelContext* ctx) override { |
112 | core::RefCountPtr<SummaryWriterInterface> s; |
113 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); |
114 | OP_REQUIRES_OK(ctx, s->Flush()); |
115 | } |
116 | }; |
117 | REGISTER_KERNEL_BUILDER(Name("FlushSummaryWriter" ).Device(DEVICE_CPU), |
118 | FlushSummaryWriterOp); |
119 | |
120 | class CloseSummaryWriterOp : public OpKernel { |
121 | public: |
122 | explicit CloseSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
123 | |
124 | void Compute(OpKernelContext* ctx) override { |
125 | OP_REQUIRES_OK(ctx, DeleteResource<SummaryWriterInterface>( |
126 | ctx, HandleFromInput(ctx, 0))); |
127 | } |
128 | }; |
129 | REGISTER_KERNEL_BUILDER(Name("CloseSummaryWriter" ).Device(DEVICE_CPU), |
130 | CloseSummaryWriterOp); |
131 | |
132 | class WriteSummaryOp : public OpKernel { |
133 | public: |
134 | explicit WriteSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
135 | |
136 | void Compute(OpKernelContext* ctx) override { |
137 | core::RefCountPtr<SummaryWriterInterface> s; |
138 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); |
139 | const Tensor* tmp; |
140 | OP_REQUIRES_OK(ctx, ctx->input("step" , &tmp)); |
141 | const int64_t step = tmp->scalar<int64_t>()(); |
142 | OP_REQUIRES_OK(ctx, ctx->input("tag" , &tmp)); |
143 | const string& tag = tmp->scalar<tstring>()(); |
144 | OP_REQUIRES_OK(ctx, ctx->input("summary_metadata" , &tmp)); |
145 | const string& serialized_metadata = tmp->scalar<tstring>()(); |
146 | |
147 | const Tensor* t; |
148 | OP_REQUIRES_OK(ctx, ctx->input("tensor" , &t)); |
149 | |
150 | OP_REQUIRES_OK(ctx, s->WriteTensor(step, *t, tag, serialized_metadata)); |
151 | } |
152 | }; |
153 | REGISTER_KERNEL_BUILDER(Name("WriteSummary" ).Device(DEVICE_CPU), |
154 | WriteSummaryOp); |
155 | |
156 | class WriteRawProtoSummaryOp : public OpKernel { |
157 | public: |
158 | explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
159 | |
160 | void Compute(OpKernelContext* ctx) override { |
161 | core::RefCountPtr<SummaryWriterInterface> s; |
162 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); |
163 | const Tensor* tmp; |
164 | OP_REQUIRES_OK(ctx, ctx->input("step" , &tmp)); |
165 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()), |
166 | errors::InvalidArgument("step must be scalar, got shape " , |
167 | tmp->shape().DebugString())); |
168 | const int64_t step = tmp->scalar<int64_t>()(); |
169 | const Tensor* t; |
170 | OP_REQUIRES_OK(ctx, ctx->input("tensor" , &t)); |
171 | std::unique_ptr<Event> event{new Event}; |
172 | event->set_step(step); |
173 | event->set_wall_time(static_cast<double>(ctx->env()->NowMicros()) / 1.0e6); |
174 | // Each Summary proto contains just one repeated field "value" of Value |
175 | // messages with the actual data, so repeated Merge() is equivalent to |
176 | // concatenating all the Value entries together into a single Event. |
177 | const auto summary_pbs = t->flat<tstring>(); |
178 | for (int i = 0; i < summary_pbs.size(); ++i) { |
179 | if (!event->mutable_summary()->MergeFromString(summary_pbs(i))) { |
180 | ctx->CtxFailureWithWarning(errors::DataLoss( |
181 | "Bad tf.compat.v1.Summary binary proto tensor string at index " , |
182 | i)); |
183 | return; |
184 | } |
185 | } |
186 | OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event))); |
187 | } |
188 | }; |
189 | REGISTER_KERNEL_BUILDER(Name("WriteRawProtoSummary" ).Device(DEVICE_CPU), |
190 | WriteRawProtoSummaryOp); |
191 | |
192 | class ImportEventOp : public OpKernel { |
193 | public: |
194 | explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
195 | |
196 | void Compute(OpKernelContext* ctx) override { |
197 | core::RefCountPtr<SummaryWriterInterface> s; |
198 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); |
199 | const Tensor* t; |
200 | OP_REQUIRES_OK(ctx, ctx->input("event" , &t)); |
201 | std::unique_ptr<Event> event{new Event}; |
202 | if (!ParseProtoUnlimited(event.get(), t->scalar<tstring>()())) { |
203 | ctx->CtxFailureWithWarning( |
204 | errors::DataLoss("Bad tf.Event binary proto tensor string" )); |
205 | return; |
206 | } |
207 | OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event))); |
208 | } |
209 | }; |
210 | REGISTER_KERNEL_BUILDER(Name("ImportEvent" ).Device(DEVICE_CPU), ImportEventOp); |
211 | |
212 | class WriteScalarSummaryOp : public OpKernel { |
213 | public: |
214 | explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
215 | |
216 | void Compute(OpKernelContext* ctx) override { |
217 | core::RefCountPtr<SummaryWriterInterface> s; |
218 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); |
219 | const Tensor* tmp; |
220 | OP_REQUIRES_OK(ctx, ctx->input("step" , &tmp)); |
221 | const int64_t step = tmp->scalar<int64_t>()(); |
222 | OP_REQUIRES_OK(ctx, ctx->input("tag" , &tmp)); |
223 | const string& tag = tmp->scalar<tstring>()(); |
224 | |
225 | const Tensor* t; |
226 | OP_REQUIRES_OK(ctx, ctx->input("value" , &t)); |
227 | |
228 | OP_REQUIRES_OK(ctx, s->WriteScalar(step, *t, tag)); |
229 | } |
230 | }; |
231 | REGISTER_KERNEL_BUILDER(Name("WriteScalarSummary" ).Device(DEVICE_CPU), |
232 | WriteScalarSummaryOp); |
233 | |
234 | class WriteHistogramSummaryOp : public OpKernel { |
235 | public: |
236 | explicit WriteHistogramSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
237 | |
238 | void Compute(OpKernelContext* ctx) override { |
239 | core::RefCountPtr<SummaryWriterInterface> s; |
240 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); |
241 | const Tensor* tmp; |
242 | OP_REQUIRES_OK(ctx, ctx->input("step" , &tmp)); |
243 | const int64_t step = tmp->scalar<int64_t>()(); |
244 | OP_REQUIRES_OK(ctx, ctx->input("tag" , &tmp)); |
245 | const string& tag = tmp->scalar<tstring>()(); |
246 | |
247 | const Tensor* t; |
248 | OP_REQUIRES_OK(ctx, ctx->input("values" , &t)); |
249 | |
250 | OP_REQUIRES_OK(ctx, s->WriteHistogram(step, *t, tag)); |
251 | } |
252 | }; |
253 | REGISTER_KERNEL_BUILDER(Name("WriteHistogramSummary" ).Device(DEVICE_CPU), |
254 | WriteHistogramSummaryOp); |
255 | |
256 | class WriteImageSummaryOp : public OpKernel { |
257 | public: |
258 | explicit WriteImageSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
259 | int64_t max_images_tmp; |
260 | OP_REQUIRES_OK(ctx, ctx->GetAttr("max_images" , &max_images_tmp)); |
261 | OP_REQUIRES(ctx, max_images_tmp < (1LL << 31), |
262 | errors::InvalidArgument("max_images must be < 2^31" )); |
263 | max_images_ = static_cast<int32>(max_images_tmp); |
264 | } |
265 | |
266 | void Compute(OpKernelContext* ctx) override { |
267 | core::RefCountPtr<SummaryWriterInterface> s; |
268 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); |
269 | const Tensor* tmp; |
270 | OP_REQUIRES_OK(ctx, ctx->input("step" , &tmp)); |
271 | const int64_t step = tmp->scalar<int64_t>()(); |
272 | OP_REQUIRES_OK(ctx, ctx->input("tag" , &tmp)); |
273 | const string& tag = tmp->scalar<tstring>()(); |
274 | const Tensor* bad_color; |
275 | OP_REQUIRES_OK(ctx, ctx->input("bad_color" , &bad_color)); |
276 | OP_REQUIRES( |
277 | ctx, TensorShapeUtils::IsVector(bad_color->shape()), |
278 | errors::InvalidArgument("bad_color must be a vector, got shape " , |
279 | bad_color->shape().DebugString())); |
280 | |
281 | const Tensor* t; |
282 | OP_REQUIRES_OK(ctx, ctx->input("tensor" , &t)); |
283 | |
284 | OP_REQUIRES_OK(ctx, s->WriteImage(step, *t, tag, max_images_, *bad_color)); |
285 | } |
286 | |
287 | private: |
288 | int32 max_images_; |
289 | }; |
290 | REGISTER_KERNEL_BUILDER(Name("WriteImageSummary" ).Device(DEVICE_CPU), |
291 | WriteImageSummaryOp); |
292 | |
293 | class WriteAudioSummaryOp : public OpKernel { |
294 | public: |
295 | explicit WriteAudioSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
296 | OP_REQUIRES_OK(ctx, ctx->GetAttr("max_outputs" , &max_outputs_)); |
297 | OP_REQUIRES(ctx, max_outputs_ > 0, |
298 | errors::InvalidArgument("max_outputs must be > 0" )); |
299 | } |
300 | |
301 | void Compute(OpKernelContext* ctx) override { |
302 | core::RefCountPtr<SummaryWriterInterface> s; |
303 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); |
304 | const Tensor* tmp; |
305 | OP_REQUIRES_OK(ctx, ctx->input("step" , &tmp)); |
306 | const int64_t step = tmp->scalar<int64_t>()(); |
307 | OP_REQUIRES_OK(ctx, ctx->input("tag" , &tmp)); |
308 | const string& tag = tmp->scalar<tstring>()(); |
309 | OP_REQUIRES_OK(ctx, ctx->input("sample_rate" , &tmp)); |
310 | const float sample_rate = tmp->scalar<float>()(); |
311 | |
312 | const Tensor* t; |
313 | OP_REQUIRES_OK(ctx, ctx->input("tensor" , &t)); |
314 | |
315 | OP_REQUIRES_OK(ctx, |
316 | s->WriteAudio(step, *t, tag, max_outputs_, sample_rate)); |
317 | } |
318 | |
319 | private: |
320 | int max_outputs_; |
321 | }; |
322 | REGISTER_KERNEL_BUILDER(Name("WriteAudioSummary" ).Device(DEVICE_CPU), |
323 | WriteAudioSummaryOp); |
324 | |
325 | class WriteGraphSummaryOp : public OpKernel { |
326 | public: |
327 | explicit WriteGraphSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
328 | |
329 | void Compute(OpKernelContext* ctx) override { |
330 | core::RefCountPtr<SummaryWriterInterface> s; |
331 | OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s)); |
332 | const Tensor* t; |
333 | OP_REQUIRES_OK(ctx, ctx->input("step" , &t)); |
334 | const int64_t step = t->scalar<int64_t>()(); |
335 | OP_REQUIRES_OK(ctx, ctx->input("tensor" , &t)); |
336 | std::unique_ptr<GraphDef> graph{new GraphDef}; |
337 | if (!ParseProtoUnlimited(graph.get(), t->scalar<tstring>()())) { |
338 | ctx->CtxFailureWithWarning( |
339 | errors::DataLoss("Bad tf.GraphDef binary proto tensor string" )); |
340 | return; |
341 | } |
342 | OP_REQUIRES_OK(ctx, s->WriteGraph(step, std::move(graph))); |
343 | } |
344 | }; |
345 | REGISTER_KERNEL_BUILDER(Name("WriteGraphSummary" ).Device(DEVICE_CPU), |
346 | WriteGraphSummaryOp); |
347 | |
348 | } // namespace tensorflow |
349 | |