1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/summary/summary_db_writer.h"
16
17#include <deque>
18
19#include "tensorflow/core/summary/summary_converter.h"
20#include "tensorflow/core/framework/graph.pb.h"
21#include "tensorflow/core/framework/node_def.pb.h"
22#include "tensorflow/core/framework/register_types.h"
23#include "tensorflow/core/framework/summary.pb.h"
24#include "tensorflow/core/lib/core/stringpiece.h"
25#include "tensorflow/core/lib/db/sqlite.h"
26#include "tensorflow/core/lib/random/random.h"
27#include "tensorflow/core/util/event.pb.h"
28
29// TODO(jart): Break this up into multiple files with excellent unit tests.
30// TODO(jart): Make decision to write in separate op.
31// TODO(jart): Add really good busy handling.
32
// X-macro over the element types this writer accepts. Each TF_CALL_<type>
// helper applies `m` for its type; CheckSupportedType() and AsScalar() below
// pass a CASE macro here to generate one switch case per supported dtype.
// clang-format off
#define CALL_SUPPORTED_TYPES(m) \
  TF_CALL_tstring(m) \
  TF_CALL_half(m) \
  TF_CALL_float(m) \
  TF_CALL_double(m) \
  TF_CALL_complex64(m) \
  TF_CALL_complex128(m) \
  TF_CALL_int8(m) \
  TF_CALL_int16(m) \
  TF_CALL_int32(m) \
  TF_CALL_int64(m) \
  TF_CALL_uint8(m) \
  TF_CALL_uint16(m) \
  TF_CALL_uint32(m) \
  TF_CALL_uint64(m)
// clang-format on
50
51namespace tensorflow {
52namespace {
53
54// https://www.sqlite.org/fileformat.html#record_format
55const uint64 kIdTiers[] = {
56 0x7fffffULL, // 23-bit (3 bytes on disk)
57 0x7fffffffULL, // 31-bit (4 bytes on disk)
58 0x7fffffffffffULL, // 47-bit (5 bytes on disk)
59 // remaining bits for future use
60};
61const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64);
62const int kIdCollisionDelayMicros = 10;
63const int kMaxIdCollisions = 21; // sum(2**i*10µs for i in range(21))~=21s
64const int64_t kAbsent = 0LL;
65
66const char* kScalarPluginName = "scalars";
67const char* kImagePluginName = "images";
68const char* kAudioPluginName = "audio";
69const char* kHistogramPluginName = "histograms";
70
71const int64_t kReserveMinBytes = 32;
72const double kReserveMultiplier = 1.5;
73const int64_t kPreallocateRows = 1000;
74
75// Flush is a misnomer because what we're actually doing is having lots
76// of commits inside any SqliteTransaction that writes potentially
77// hundreds of megs but doesn't need the transaction to maintain its
78// invariants. This ensures the WAL read penalty is small and might
79// allow writers in other processes a chance to schedule.
80const uint64 kFlushBytes = 1024 * 1024;
81
82double DoubleTime(uint64 micros) {
83 // TODO(@jart): Follow precise definitions for time laid out in schema.
84 // TODO(@jart): Use monotonic clock from gRPC codebase.
85 return static_cast<double>(micros) / 1.0e6;
86}
87
88string StringifyShape(const TensorShape& shape) {
89 string result;
90 bool first = true;
91 for (const auto& dim : shape) {
92 if (first) {
93 first = false;
94 } else {
95 strings::StrAppend(&result, ",");
96 }
97 strings::StrAppend(&result, dim.size);
98 }
99 return result;
100}
101
102Status CheckSupportedType(const Tensor& t) {
103#define CASE(T) \
104 case DataTypeToEnum<T>::value: \
105 break;
106 switch (t.dtype()) {
107 CALL_SUPPORTED_TYPES(CASE)
108 default:
109 return errors::Unimplemented(DataTypeString(t.dtype()),
110 " tensors unsupported on platform");
111 }
112 return OkStatus();
113#undef CASE
114}
115
116Tensor AsScalar(const Tensor& t) {
117 Tensor t2{t.dtype(), {}};
118#define CASE(T) \
119 case DataTypeToEnum<T>::value: \
120 t2.scalar<T>()() = t.flat<T>()(0); \
121 break;
122 switch (t.dtype()) {
123 CALL_SUPPORTED_TYPES(CASE)
124 default:
125 t2 = {DT_FLOAT, {}};
126 t2.scalar<float>()() = NAN;
127 break;
128 }
129 return t2;
130#undef CASE
131}
132
133void PatchPluginName(SummaryMetadata* metadata, const char* name) {
134 if (metadata->plugin_data().plugin_name().empty()) {
135 metadata->mutable_plugin_data()->set_plugin_name(name);
136 }
137}
138
139Status SetDescription(Sqlite* db, int64_t id, const StringPiece& markdown) {
140 const char* sql = R"sql(
141 INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?)
142 )sql";
143 SqliteStatement insert_desc;
144 TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc));
145 insert_desc.BindInt(1, id);
146 insert_desc.BindText(2, markdown);
147 return insert_desc.StepAndReset();
148}
149
150/// \brief Generates unique IDs randomly in the [1,2**63-1] range.
151///
152/// This class starts off generating IDs in the [1,2**23-1] range,
153/// because it's human friendly and occupies 4 bytes max on disk with
154/// SQLite's zigzag varint encoding. Then, each time a collision
155/// happens, the random space is increased by 8 bits.
156///
157/// This class uses exponential back-off so writes gradually slow down
158/// as IDs become exhausted but reads are still possible.
159///
160/// This class is thread safe.
161class IdAllocator {
162 public:
163 IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} {
164 DCHECK(env_ != nullptr);
165 DCHECK(db_ != nullptr);
166 }
167
168 Status CreateNewId(int64_t* id) TF_LOCKS_EXCLUDED(mu_) {
169 mutex_lock lock(mu_);
170 Status s;
171 SqliteStatement stmt;
172 TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt));
173 for (int i = 0; i < kMaxIdCollisions; ++i) {
174 int64_t tid = MakeRandomId();
175 stmt.BindInt(1, tid);
176 s = stmt.StepAndReset();
177 if (s.ok()) {
178 *id = tid;
179 break;
180 }
181 // SQLITE_CONSTRAINT maps to INVALID_ARGUMENT in sqlite.cc
182 if (s.code() != error::INVALID_ARGUMENT) break;
183 if (tier_ < kMaxIdTier) {
184 LOG(INFO) << "IdAllocator collision at tier " << tier_ << " (of "
185 << kMaxIdTier << ") so auto-adjusting to a higher tier";
186 ++tier_;
187 } else {
188 LOG(WARNING) << "IdAllocator (attempt #" << i << ") "
189 << "resulted in a collision at the highest tier; this "
190 "is problematic if it happens often; you can try "
191 "pruning the Ids table; you can also file a bug "
192 "asking for the ID space to be increased; otherwise "
193 "writes will gradually slow down over time until they "
194 "become impossible";
195 }
196 env_->SleepForMicroseconds((1 << i) * kIdCollisionDelayMicros);
197 }
198 return s;
199 }
200
201 private:
202 int64_t MakeRandomId() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
203 int64_t id = static_cast<int64_t>(random::New64() & kIdTiers[tier_]);
204 if (id == kAbsent) ++id;
205 return id;
206 }
207
208 mutex mu_;
209 Env* const env_;
210 Sqlite* const db_;
211 int tier_ TF_GUARDED_BY(mu_) = 0;
212
213 TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator);
214};
215
216class GraphWriter {
217 public:
218 static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids,
219 GraphDef* graph, uint64 now, int64_t run_id,
220 int64_t* graph_id)
221 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
222 TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id));
223 GraphWriter saver{db, txn, graph, now, *graph_id};
224 saver.MapNameToNodeId();
225 TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs");
226 TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes");
227 TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph");
228 return OkStatus();
229 }
230
231 private:
232 GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now,
233 int64_t graph_id)
234 : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {}
235
236 void MapNameToNodeId() {
237 size_t toto = static_cast<size_t>(graph_->node_size());
238 name_copies_.reserve(toto);
239 name_to_node_id_.reserve(toto);
240 for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
241 // Copy name into memory region, since we call clear_name() later.
242 // Then wrap in StringPiece so we can compare slices without copy.
243 name_copies_.emplace_back(graph_->node(node_id).name());
244 name_to_node_id_.emplace(name_copies_.back(), node_id);
245 }
246 }
247
248 Status SaveNodeInputs() {
249 const char* sql = R"sql(
250 INSERT INTO NodeInputs (
251 graph_id,
252 node_id,
253 idx,
254 input_node_id,
255 input_node_idx,
256 is_control
257 ) VALUES (?, ?, ?, ?, ?, ?)
258 )sql";
259 SqliteStatement insert;
260 TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
261 for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
262 const NodeDef& node = graph_->node(node_id);
263 for (int idx = 0; idx < node.input_size(); ++idx) {
264 StringPiece name = node.input(idx);
265 int64_t input_node_id;
266 int64_t input_node_idx = 0;
267 int64_t is_control = 0;
268 size_t i = name.rfind(':');
269 if (i != StringPiece::npos) {
270 if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1),
271 &input_node_idx)) {
272 return errors::DataLoss("Bad NodeDef.input: ", name);
273 }
274 name.remove_suffix(name.size() - i);
275 }
276 if (!name.empty() && name[0] == '^') {
277 name.remove_prefix(1);
278 is_control = 1;
279 }
280 auto e = name_to_node_id_.find(name);
281 if (e == name_to_node_id_.end()) {
282 return errors::DataLoss("Could not find node: ", name);
283 }
284 input_node_id = e->second;
285 insert.BindInt(1, graph_id_);
286 insert.BindInt(2, node_id);
287 insert.BindInt(3, idx);
288 insert.BindInt(4, input_node_id);
289 insert.BindInt(5, input_node_idx);
290 insert.BindInt(6, is_control);
291 unflushed_bytes_ += insert.size();
292 TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(),
293 " -> ", name);
294 TF_RETURN_IF_ERROR(MaybeFlush());
295 }
296 }
297 return OkStatus();
298 }
299
300 Status SaveNodes() {
301 const char* sql = R"sql(
302 INSERT INTO Nodes (
303 graph_id,
304 node_id,
305 node_name,
306 op,
307 device,
308 node_def)
309 VALUES (?, ?, ?, ?, ?, ?)
310 )sql";
311 SqliteStatement insert;
312 TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
313 for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
314 NodeDef* node = graph_->mutable_node(node_id);
315 insert.BindInt(1, graph_id_);
316 insert.BindInt(2, node_id);
317 insert.BindText(3, node->name());
318 insert.BindText(4, node->op());
319 insert.BindText(5, node->device());
320 node->clear_name();
321 node->clear_op();
322 node->clear_device();
323 node->clear_input();
324 string node_def;
325 if (node->SerializeToString(&node_def)) {
326 insert.BindBlobUnsafe(6, node_def);
327 }
328 unflushed_bytes_ += insert.size();
329 TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name());
330 TF_RETURN_IF_ERROR(MaybeFlush());
331 }
332 return OkStatus();
333 }
334
335 Status SaveGraph(int64_t run_id) {
336 const char* sql = R"sql(
337 INSERT OR REPLACE INTO Graphs (
338 run_id,
339 graph_id,
340 inserted_time,
341 graph_def
342 ) VALUES (?, ?, ?, ?)
343 )sql";
344 SqliteStatement insert;
345 TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
346 if (run_id != kAbsent) insert.BindInt(1, run_id);
347 insert.BindInt(2, graph_id_);
348 insert.BindDouble(3, DoubleTime(now_));
349 graph_->clear_node();
350 string graph_def;
351 if (graph_->SerializeToString(&graph_def)) {
352 insert.BindBlobUnsafe(4, graph_def);
353 }
354 return insert.StepAndReset();
355 }
356
357 Status MaybeFlush() {
358 if (unflushed_bytes_ >= kFlushBytes) {
359 TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ",
360 unflushed_bytes_, " bytes");
361 unflushed_bytes_ = 0;
362 }
363 return OkStatus();
364 }
365
366 Sqlite* const db_;
367 SqliteTransaction* const txn_;
368 uint64 unflushed_bytes_ = 0;
369 GraphDef* const graph_;
370 const uint64 now_;
371 const int64_t graph_id_;
372 std::vector<string> name_copies_;
373 std::unordered_map<StringPiece, int64_t, StringPieceHasher> name_to_node_id_;
374
375 TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter);
376};
377
378/// \brief Run metadata manager.
379///
380/// This class gives us Tag IDs we can pass to SeriesWriter. In order
381/// to do that, rows are created in the Ids, Tags, Runs, Experiments,
382/// and Users tables.
383///
384/// This class is thread safe.
385class RunMetadata {
386 public:
387 RunMetadata(IdAllocator* ids, const string& experiment_name,
388 const string& run_name, const string& user_name)
389 : ids_{ids},
390 experiment_name_{experiment_name},
391 run_name_{run_name},
392 user_name_{user_name} {
393 DCHECK(ids_ != nullptr);
394 }
395
396 const string& experiment_name() { return experiment_name_; }
397 const string& run_name() { return run_name_; }
398 const string& user_name() { return user_name_; }
399
400 int64_t run_id() TF_LOCKS_EXCLUDED(mu_) {
401 mutex_lock lock(mu_);
402 return run_id_;
403 }
404
405 Status SetGraph(Sqlite* db, uint64 now, double computed_time,
406 std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db)
407 TF_LOCKS_EXCLUDED(mu_) {
408 int64_t run_id;
409 {
410 mutex_lock lock(mu_);
411 TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
412 run_id = run_id_;
413 }
414 int64_t graph_id;
415 SqliteTransaction txn(*db); // only to increase performance
416 TF_RETURN_IF_ERROR(
417 GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id));
418 return txn.Commit();
419 }
420
421 Status GetTagId(Sqlite* db, uint64 now, double computed_time,
422 const string& tag_name, int64_t* tag_id,
423 const SummaryMetadata& metadata) TF_LOCKS_EXCLUDED(mu_) {
424 mutex_lock lock(mu_);
425 TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
426 auto e = tag_ids_.find(tag_name);
427 if (e != tag_ids_.end()) {
428 *tag_id = e->second;
429 return OkStatus();
430 }
431 TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id));
432 tag_ids_[tag_name] = *tag_id;
433 TF_RETURN_IF_ERROR(
434 SetDescription(db, *tag_id, metadata.summary_description()));
435 const char* sql = R"sql(
436 INSERT INTO Tags (
437 run_id,
438 tag_id,
439 tag_name,
440 inserted_time,
441 display_name,
442 plugin_name,
443 plugin_data
444 ) VALUES (
445 :run_id,
446 :tag_id,
447 :tag_name,
448 :inserted_time,
449 :display_name,
450 :plugin_name,
451 :plugin_data
452 )
453 )sql";
454 SqliteStatement insert;
455 TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
456 if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_);
457 insert.BindInt(":tag_id", *tag_id);
458 insert.BindTextUnsafe(":tag_name", tag_name);
459 insert.BindDouble(":inserted_time", DoubleTime(now));
460 insert.BindTextUnsafe(":display_name", metadata.display_name());
461 insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name());
462 insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content());
463 return insert.StepAndReset();
464 }
465
466 private:
467 Status InitializeUser(Sqlite* db, uint64 now)
468 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
469 if (user_id_ != kAbsent || user_name_.empty()) return OkStatus();
470 const char* get_sql = R"sql(
471 SELECT user_id FROM Users WHERE user_name = ?
472 )sql";
473 SqliteStatement get;
474 TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
475 get.BindText(1, user_name_);
476 bool is_done;
477 TF_RETURN_IF_ERROR(get.Step(&is_done));
478 if (!is_done) {
479 user_id_ = get.ColumnInt(0);
480 return OkStatus();
481 }
482 TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_));
483 const char* insert_sql = R"sql(
484 INSERT INTO Users (
485 user_id,
486 user_name,
487 inserted_time
488 ) VALUES (?, ?, ?)
489 )sql";
490 SqliteStatement insert;
491 TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
492 insert.BindInt(1, user_id_);
493 insert.BindText(2, user_name_);
494 insert.BindDouble(3, DoubleTime(now));
495 TF_RETURN_IF_ERROR(insert.StepAndReset());
496 return OkStatus();
497 }
498
499 Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time)
500 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
501 if (experiment_name_.empty()) return OkStatus();
502 if (experiment_id_ == kAbsent) {
503 TF_RETURN_IF_ERROR(InitializeUser(db, now));
504 const char* get_sql = R"sql(
505 SELECT
506 experiment_id,
507 started_time
508 FROM
509 Experiments
510 WHERE
511 user_id IS ?
512 AND experiment_name = ?
513 )sql";
514 SqliteStatement get;
515 TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
516 if (user_id_ != kAbsent) get.BindInt(1, user_id_);
517 get.BindText(2, experiment_name_);
518 bool is_done;
519 TF_RETURN_IF_ERROR(get.Step(&is_done));
520 if (!is_done) {
521 experiment_id_ = get.ColumnInt(0);
522 experiment_started_time_ = get.ColumnInt(1);
523 } else {
524 TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_));
525 experiment_started_time_ = computed_time;
526 const char* insert_sql = R"sql(
527 INSERT INTO Experiments (
528 user_id,
529 experiment_id,
530 experiment_name,
531 inserted_time,
532 started_time,
533 is_watching
534 ) VALUES (?, ?, ?, ?, ?, ?)
535 )sql";
536 SqliteStatement insert;
537 TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
538 if (user_id_ != kAbsent) insert.BindInt(1, user_id_);
539 insert.BindInt(2, experiment_id_);
540 insert.BindText(3, experiment_name_);
541 insert.BindDouble(4, DoubleTime(now));
542 insert.BindDouble(5, computed_time);
543 insert.BindInt(6, 0);
544 TF_RETURN_IF_ERROR(insert.StepAndReset());
545 }
546 }
547 if (computed_time < experiment_started_time_) {
548 experiment_started_time_ = computed_time;
549 const char* update_sql = R"sql(
550 UPDATE
551 Experiments
552 SET
553 started_time = ?
554 WHERE
555 experiment_id = ?
556 )sql";
557 SqliteStatement update;
558 TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
559 update.BindDouble(1, computed_time);
560 update.BindInt(2, experiment_id_);
561 TF_RETURN_IF_ERROR(update.StepAndReset());
562 }
563 return OkStatus();
564 }
565
566 Status InitializeRun(Sqlite* db, uint64 now, double computed_time)
567 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
568 if (run_name_.empty()) return OkStatus();
569 TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time));
570 if (run_id_ == kAbsent) {
571 TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_));
572 run_started_time_ = computed_time;
573 const char* insert_sql = R"sql(
574 INSERT OR REPLACE INTO Runs (
575 experiment_id,
576 run_id,
577 run_name,
578 inserted_time,
579 started_time
580 ) VALUES (?, ?, ?, ?, ?)
581 )sql";
582 SqliteStatement insert;
583 TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
584 if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_);
585 insert.BindInt(2, run_id_);
586 insert.BindText(3, run_name_);
587 insert.BindDouble(4, DoubleTime(now));
588 insert.BindDouble(5, computed_time);
589 TF_RETURN_IF_ERROR(insert.StepAndReset());
590 }
591 if (computed_time < run_started_time_) {
592 run_started_time_ = computed_time;
593 const char* update_sql = R"sql(
594 UPDATE
595 Runs
596 SET
597 started_time = ?
598 WHERE
599 run_id = ?
600 )sql";
601 SqliteStatement update;
602 TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
603 update.BindDouble(1, computed_time);
604 update.BindInt(2, run_id_);
605 TF_RETURN_IF_ERROR(update.StepAndReset());
606 }
607 return OkStatus();
608 }
609
610 mutex mu_;
611 IdAllocator* const ids_;
612 const string experiment_name_;
613 const string run_name_;
614 const string user_name_;
615 int64_t experiment_id_ TF_GUARDED_BY(mu_) = kAbsent;
616 int64_t run_id_ TF_GUARDED_BY(mu_) = kAbsent;
617 int64_t user_id_ TF_GUARDED_BY(mu_) = kAbsent;
618 double experiment_started_time_ TF_GUARDED_BY(mu_) = 0.0;
619 double run_started_time_ TF_GUARDED_BY(mu_) = 0.0;
620 std::unordered_map<string, int64_t> tag_ids_ TF_GUARDED_BY(mu_);
621
622 TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata);
623};
624
625/// \brief Tensor writer for a single series, e.g. Tag.
626///
627/// This class is thread safe.
628class SeriesWriter {
629 public:
630 SeriesWriter(int64_t series, RunMetadata* meta)
631 : series_{series}, meta_{meta} {
632 DCHECK(series_ > 0);
633 }
634
635 Status Append(Sqlite* db, int64_t step, uint64 now, double computed_time,
636 const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
637 TF_LOCKS_EXCLUDED(mu_) {
638 mutex_lock lock(mu_);
639 if (rowids_.empty()) {
640 Status s = Reserve(db, t);
641 if (!s.ok()) {
642 rowids_.clear();
643 return s;
644 }
645 }
646 int64_t rowid = rowids_.front();
647 Status s = Write(db, rowid, step, computed_time, t);
648 if (s.ok()) {
649 ++count_;
650 }
651 rowids_.pop_front();
652 return s;
653 }
654
655 Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
656 TF_LOCKS_EXCLUDED(mu_) {
657 mutex_lock lock(mu_);
658 // Delete unused pre-allocated Tensors.
659 if (!rowids_.empty()) {
660 SqliteTransaction txn(*db);
661 const char* sql = R"sql(
662 DELETE FROM Tensors WHERE rowid = ?
663 )sql";
664 SqliteStatement deleter;
665 TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter));
666 for (size_t i = count_; i < rowids_.size(); ++i) {
667 deleter.BindInt(1, rowids_.front());
668 TF_RETURN_IF_ERROR(deleter.StepAndReset());
669 rowids_.pop_front();
670 }
671 TF_RETURN_IF_ERROR(txn.Commit());
672 rowids_.clear();
673 }
674 return OkStatus();
675 }
676
677 private:
678 Status Write(Sqlite* db, int64_t rowid, int64_t step, double computed_time,
679 const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) {
680 if (t.dtype() == DT_STRING) {
681 if (t.dims() == 0) {
682 return Update(db, step, computed_time, t, t.scalar<tstring>()(), rowid);
683 } else {
684 SqliteTransaction txn(*db);
685 TF_RETURN_IF_ERROR(
686 Update(db, step, computed_time, t, StringPiece(), rowid));
687 TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid));
688 return txn.Commit();
689 }
690 } else {
691 return Update(db, step, computed_time, t, t.tensor_data(), rowid);
692 }
693 }
694
695 Status Update(Sqlite* db, int64_t step, double computed_time, const Tensor& t,
696 const StringPiece& data, int64_t rowid) {
697 const char* sql = R"sql(
698 UPDATE OR REPLACE
699 Tensors
700 SET
701 step = ?,
702 computed_time = ?,
703 dtype = ?,
704 shape = ?,
705 data = ?
706 WHERE
707 rowid = ?
708 )sql";
709 SqliteStatement stmt;
710 TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
711 stmt.BindInt(1, step);
712 stmt.BindDouble(2, computed_time);
713 stmt.BindInt(3, t.dtype());
714 stmt.BindText(4, StringifyShape(t.shape()));
715 stmt.BindBlobUnsafe(5, data);
716 stmt.BindInt(6, rowid);
717 TF_RETURN_IF_ERROR(stmt.StepAndReset());
718 return OkStatus();
719 }
720
721 Status UpdateNdString(Sqlite* db, const Tensor& t, int64_t tensor_rowid)
722 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
723 DCHECK_EQ(t.dtype(), DT_STRING);
724 DCHECK_GT(t.dims(), 0);
725 const char* deleter_sql = R"sql(
726 DELETE FROM TensorStrings WHERE tensor_rowid = ?
727 )sql";
728 SqliteStatement deleter;
729 TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter));
730 deleter.BindInt(1, tensor_rowid);
731 TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid);
732 const char* inserter_sql = R"sql(
733 INSERT INTO TensorStrings (
734 tensor_rowid,
735 idx,
736 data
737 ) VALUES (?, ?, ?)
738 )sql";
739 SqliteStatement inserter;
740 TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter));
741 auto flat = t.flat<tstring>();
742 for (int64_t i = 0; i < flat.size(); ++i) {
743 inserter.BindInt(1, tensor_rowid);
744 inserter.BindInt(2, i);
745 inserter.BindBlobUnsafe(3, flat(i));
746 TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i);
747 }
748 return OkStatus();
749 }
750
751 Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
752 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
753 SqliteTransaction txn(*db); // only for performance
754 unflushed_bytes_ = 0;
755 if (t.dtype() == DT_STRING) {
756 if (t.dims() == 0) {
757 TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<tstring>()().size()));
758 } else {
759 TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes));
760 }
761 } else {
762 TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size()));
763 }
764 return txn.Commit();
765 }
766
767 Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size)
768 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
769 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
770 int64_t space =
771 static_cast<int64_t>(static_cast<double>(size) * kReserveMultiplier);
772 if (space < kReserveMinBytes) space = kReserveMinBytes;
773 return ReserveTensors(db, txn, space);
774 }
775
776 Status ReserveTensors(Sqlite* db, SqliteTransaction* txn,
777 int64_t reserved_bytes)
778 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
779 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
780 const char* sql = R"sql(
781 INSERT INTO Tensors (
782 series,
783 data
784 ) VALUES (?, ZEROBLOB(?))
785 )sql";
786 SqliteStatement insert;
787 TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
788 // TODO(jart): Maybe preallocate index pages by setting step. This
789 // is tricky because UPDATE OR REPLACE can have a side
790 // effect of deleting preallocated rows.
791 for (int64_t i = 0; i < kPreallocateRows; ++i) {
792 insert.BindInt(1, series_);
793 insert.BindInt(2, reserved_bytes);
794 TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i);
795 rowids_.push_back(db->last_insert_rowid());
796 unflushed_bytes_ += reserved_bytes;
797 TF_RETURN_IF_ERROR(MaybeFlush(db, txn));
798 }
799 return OkStatus();
800 }
801
802 Status MaybeFlush(Sqlite* db, SqliteTransaction* txn)
803 SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
804 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
805 if (unflushed_bytes_ >= kFlushBytes) {
806 TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ",
807 unflushed_bytes_, " bytes");
808 unflushed_bytes_ = 0;
809 }
810 return OkStatus();
811 }
812
813 mutex mu_;
814 const int64_t series_;
815 RunMetadata* const meta_;
816 uint64 count_ TF_GUARDED_BY(mu_) = 0;
817 std::deque<int64_t> rowids_ TF_GUARDED_BY(mu_);
818 uint64 unflushed_bytes_ TF_GUARDED_BY(mu_) = 0;
819
820 TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter);
821};
822
823/// \brief Tensor writer for a single Run.
824///
825/// This class farms out tensors to SeriesWriter instances. It also
826/// keeps track of whether or not someone is watching the TensorBoard
827/// GUI, so it can avoid writes when possible.
828///
829/// This class is thread safe.
830class RunWriter {
831 public:
832 explicit RunWriter(RunMetadata* meta) : meta_{meta} {}
833
834 Status Append(Sqlite* db, int64_t tag_id, int64_t step, uint64 now,
835 double computed_time, const Tensor& t)
836 SQLITE_TRANSACTIONS_EXCLUDED(*db) TF_LOCKS_EXCLUDED(mu_) {
837 SeriesWriter* writer = GetSeriesWriter(tag_id);
838 return writer->Append(db, step, now, computed_time, t);
839 }
840
841 Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
842 TF_LOCKS_EXCLUDED(mu_) {
843 mutex_lock lock(mu_);
844 if (series_writers_.empty()) return OkStatus();
845 for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) {
846 if (!i->second) continue;
847 TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db),
848 "finish tag_id=", i->first);
849 i->second.reset();
850 }
851 return OkStatus();
852 }
853
854 private:
855 SeriesWriter* GetSeriesWriter(int64_t tag_id) TF_LOCKS_EXCLUDED(mu_) {
856 mutex_lock sl(mu_);
857 auto spot = series_writers_.find(tag_id);
858 if (spot == series_writers_.end()) {
859 SeriesWriter* writer = new SeriesWriter(tag_id, meta_);
860 series_writers_[tag_id].reset(writer);
861 return writer;
862 } else {
863 return spot->second.get();
864 }
865 }
866
867 mutex mu_;
868 RunMetadata* const meta_;
869 std::unordered_map<int64_t, std::unique_ptr<SeriesWriter>> series_writers_
870 TF_GUARDED_BY(mu_);
871
872 TF_DISALLOW_COPY_AND_ASSIGN(RunWriter);
873};
874
875/// \brief SQLite implementation of SummaryWriterInterface.
876///
877/// This class is thread safe.
878class SummaryDbWriter : public SummaryWriterInterface {
879 public:
880 SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name,
881 const string& run_name, const string& user_name)
882 : SummaryWriterInterface(),
883 env_{env},
884 db_{db},
885 ids_{env_, db_},
886 meta_{&ids_, experiment_name, run_name, user_name},
887 run_{&meta_} {
888 DCHECK(env_ != nullptr);
889 db_->Ref();
890 }
891
892 ~SummaryDbWriter() override {
893 core::ScopedUnref unref(db_);
894 Status s = run_.Finish(db_);
895 if (!s.ok()) {
896 // TODO(jart): Retry on transient errors here.
897 LOG(ERROR) << s.ToString();
898 }
899 int64_t run_id = meta_.run_id();
900 if (run_id == kAbsent) return;
901 const char* sql = R"sql(
902 UPDATE Runs SET finished_time = ? WHERE run_id = ?
903 )sql";
904 SqliteStatement update;
905 s = db_->Prepare(sql, &update);
906 if (s.ok()) {
907 update.BindDouble(1, DoubleTime(env_->NowMicros()));
908 update.BindInt(2, run_id);
909 s = update.StepAndReset();
910 }
911 if (!s.ok()) {
912 LOG(ERROR) << "Failed to set Runs[" << run_id
913 << "].finish_time: " << s.ToString();
914 }
915 }
916
917 Status Flush() override { return OkStatus(); }
918
919 Status WriteTensor(int64_t global_step, Tensor t, const string& tag,
920 const string& serialized_metadata) override {
921 TF_RETURN_IF_ERROR(CheckSupportedType(t));
922 SummaryMetadata metadata;
923 if (!metadata.ParseFromString(serialized_metadata)) {
924 return errors::InvalidArgument("Bad serialized_metadata");
925 }
926 return Write(global_step, t, tag, metadata);
927 }
928
929 Status WriteScalar(int64_t global_step, Tensor t,
930 const string& tag) override {
931 TF_RETURN_IF_ERROR(CheckSupportedType(t));
932 SummaryMetadata metadata;
933 PatchPluginName(&metadata, kScalarPluginName);
934 return Write(global_step, AsScalar(t), tag, metadata);
935 }
936
937 Status WriteGraph(int64_t global_step, std::unique_ptr<GraphDef> g) override {
938 uint64 now = env_->NowMicros();
939 return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g));
940 }
941
942 Status WriteEvent(std::unique_ptr<Event> e) override {
943 return MigrateEvent(std::move(e));
944 }
945
946 Status WriteHistogram(int64_t global_step, Tensor t,
947 const string& tag) override {
948 uint64 now = env_->NowMicros();
949 std::unique_ptr<Event> e{new Event};
950 e->set_step(global_step);
951 e->set_wall_time(DoubleTime(now));
952 TF_RETURN_IF_ERROR(
953 AddTensorAsHistogramToSummary(t, tag, e->mutable_summary()));
954 return MigrateEvent(std::move(e));
955 }
956
957 Status WriteImage(int64_t global_step, Tensor t, const string& tag,
958 int max_images, Tensor bad_color) override {
959 uint64 now = env_->NowMicros();
960 std::unique_ptr<Event> e{new Event};
961 e->set_step(global_step);
962 e->set_wall_time(DoubleTime(now));
963 TF_RETURN_IF_ERROR(AddTensorAsImageToSummary(t, tag, max_images, bad_color,
964 e->mutable_summary()));
965 return MigrateEvent(std::move(e));
966 }
967
968 Status WriteAudio(int64_t global_step, Tensor t, const string& tag,
969 int max_outputs, float sample_rate) override {
970 uint64 now = env_->NowMicros();
971 std::unique_ptr<Event> e{new Event};
972 e->set_step(global_step);
973 e->set_wall_time(DoubleTime(now));
974 TF_RETURN_IF_ERROR(AddTensorAsAudioToSummary(
975 t, tag, max_outputs, sample_rate, e->mutable_summary()));
976 return MigrateEvent(std::move(e));
977 }
978
979 string DebugString() const override { return "SummaryDbWriter"; }
980
981 private:
982 Status Write(int64_t step, const Tensor& t, const string& tag,
983 const SummaryMetadata& metadata) {
984 uint64 now = env_->NowMicros();
985 double computed_time = DoubleTime(now);
986 int64_t tag_id;
987 TF_RETURN_IF_ERROR(
988 meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata));
989 TF_RETURN_WITH_CONTEXT_IF_ERROR(
990 run_.Append(db_, tag_id, step, now, computed_time, t),
991 meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(),
992 "/", tag, "@", step);
993 return OkStatus();
994 }
995
996 Status MigrateEvent(std::unique_ptr<Event> e) {
997 switch (e->what_case()) {
998 case Event::WhatCase::kSummary: {
999 uint64 now = env_->NowMicros();
1000 auto summaries = e->mutable_summary();
1001 for (int i = 0; i < summaries->value_size(); ++i) {
1002 Summary::Value* value = summaries->mutable_value(i);
1003 TF_RETURN_WITH_CONTEXT_IF_ERROR(
1004 MigrateSummary(e.get(), value, now), meta_.user_name(), "/",
1005 meta_.experiment_name(), "/", meta_.run_name(), "/", value->tag(),
1006 "@", e->step());
1007 }
1008 break;
1009 }
1010 case Event::WhatCase::kGraphDef:
1011 TF_RETURN_WITH_CONTEXT_IF_ERROR(
1012 MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/",
1013 meta_.experiment_name(), "/", meta_.run_name(), "/__graph__@",
1014 e->step());
1015 break;
1016 default:
1017 // TODO(@jart): Handle other stuff.
1018 break;
1019 }
1020 return OkStatus();
1021 }
1022
1023 Status MigrateGraph(const Event* e, const string& graph_def) {
1024 uint64 now = env_->NowMicros();
1025 std::unique_ptr<GraphDef> graph{new GraphDef};
1026 if (!ParseProtoUnlimited(graph.get(), graph_def)) {
1027 return errors::InvalidArgument("bad proto");
1028 }
1029 return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph));
1030 }
1031
1032 Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) {
1033 switch (s->value_case()) {
1034 case Summary::Value::ValueCase::kTensor:
1035 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor");
1036 break;
1037 case Summary::Value::ValueCase::kSimpleValue:
1038 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar");
1039 break;
1040 case Summary::Value::ValueCase::kHisto:
1041 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo");
1042 break;
1043 case Summary::Value::ValueCase::kImage:
1044 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image");
1045 break;
1046 case Summary::Value::ValueCase::kAudio:
1047 TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio");
1048 break;
1049 default:
1050 break;
1051 }
1052 return OkStatus();
1053 }
1054
1055 Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) {
1056 Tensor t;
1057 if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto");
1058 TF_RETURN_IF_ERROR(CheckSupportedType(t));
1059 int64_t tag_id;
1060 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1061 &tag_id, s->metadata()));
1062 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1063 }
1064
1065 // TODO(jart): Refactor Summary -> Tensor logic into separate file.
1066
1067 Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) {
1068 // See tensorboard/plugins/scalar/summary.py and data_compat.py
1069 Tensor t{DT_FLOAT, {}};
1070 t.scalar<float>()() = s->simple_value();
1071 int64_t tag_id;
1072 PatchPluginName(s->mutable_metadata(), kScalarPluginName);
1073 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1074 &tag_id, s->metadata()));
1075 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1076 }
1077
1078 Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) {
1079 const HistogramProto& histo = s->histo();
1080 int k = histo.bucket_size();
1081 if (k != histo.bucket_limit_size()) {
1082 return errors::InvalidArgument("size mismatch");
1083 }
1084 // See tensorboard/plugins/histogram/summary.py and data_compat.py
1085 Tensor t{DT_DOUBLE, {k, 3}};
1086 auto data = t.flat<double>();
1087 for (int i = 0, j = 0; i < k; ++i) {
1088 // TODO(nickfelt): reconcile with TensorBoard's data_compat.py
1089 // From summary.proto
1090 // Parallel arrays encoding the bucket boundaries and the bucket values.
1091 // bucket(i) is the count for the bucket i. The range for
1092 // a bucket is:
1093 // i == 0: -DBL_MAX .. bucket_limit(0)
1094 // i != 0: bucket_limit(i-1) .. bucket_limit(i)
1095 double left_edge = (i == 0) ? std::numeric_limits<double>::min()
1096 : histo.bucket_limit(i - 1);
1097
1098 data(j++) = left_edge;
1099 data(j++) = histo.bucket_limit(i);
1100 data(j++) = histo.bucket(i);
1101 }
1102 int64_t tag_id;
1103 PatchPluginName(s->mutable_metadata(), kHistogramPluginName);
1104 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1105 &tag_id, s->metadata()));
1106 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1107 }
1108
1109 Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) {
1110 // See tensorboard/plugins/image/summary.py and data_compat.py
1111 Tensor t{DT_STRING, {3}};
1112 auto img = s->mutable_image();
1113 t.flat<tstring>()(0) = strings::StrCat(img->width());
1114 t.flat<tstring>()(1) = strings::StrCat(img->height());
1115 t.flat<tstring>()(2) = std::move(*img->mutable_encoded_image_string());
1116 int64_t tag_id;
1117 PatchPluginName(s->mutable_metadata(), kImagePluginName);
1118 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1119 &tag_id, s->metadata()));
1120 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1121 }
1122
1123 Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) {
1124 // See tensorboard/plugins/audio/summary.py and data_compat.py
1125 Tensor t{DT_STRING, {1, 2}};
1126 auto wav = s->mutable_audio();
1127 t.flat<tstring>()(0) = std::move(*wav->mutable_encoded_audio_string());
1128 t.flat<tstring>()(1) = "";
1129 int64_t tag_id;
1130 PatchPluginName(s->mutable_metadata(), kAudioPluginName);
1131 TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(),
1132 &tag_id, s->metadata()));
1133 return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t);
1134 }
1135
1136 Env* const env_;
1137 Sqlite* const db_;
1138 IdAllocator ids_;
1139 RunMetadata meta_;
1140 RunWriter run_;
1141};
1142
1143} // namespace
1144
1145Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name,
1146 const string& run_name, const string& user_name,
1147 Env* env, SummaryWriterInterface** result) {
1148 *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name);
1149 return OkStatus();
1150}
1151
1152} // namespace tensorflow
1153