1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/summary/summary_db_writer.h" |
16 | |
17 | #include <deque> |
18 | |
19 | #include "tensorflow/core/summary/summary_converter.h" |
20 | #include "tensorflow/core/framework/graph.pb.h" |
21 | #include "tensorflow/core/framework/node_def.pb.h" |
22 | #include "tensorflow/core/framework/register_types.h" |
23 | #include "tensorflow/core/framework/summary.pb.h" |
24 | #include "tensorflow/core/lib/core/stringpiece.h" |
25 | #include "tensorflow/core/lib/db/sqlite.h" |
26 | #include "tensorflow/core/lib/random/random.h" |
27 | #include "tensorflow/core/util/event.pb.h" |
28 | |
29 | // TODO(jart): Break this up into multiple files with excellent unit tests. |
30 | // TODO(jart): Make decision to write in separate op. |
31 | // TODO(jart): Add really good busy handling. |
32 | |
33 | // clang-format off |
// Expands m(T) once for every element type this writer can serialize.
// CheckSupportedType() and AsScalar() below dispatch over exactly this set.
#define CALL_SUPPORTED_TYPES(m) \
  TF_CALL_tstring(m)            \
  TF_CALL_half(m)               \
  TF_CALL_float(m)              \
  TF_CALL_double(m)             \
  TF_CALL_complex64(m)          \
  TF_CALL_complex128(m)         \
  TF_CALL_int8(m)               \
  TF_CALL_int16(m)              \
  TF_CALL_int32(m)              \
  TF_CALL_int64(m)              \
  TF_CALL_uint8(m)              \
  TF_CALL_uint16(m)             \
  TF_CALL_uint32(m)             \
  TF_CALL_uint64(m)
49 | // clang-format on |
50 | |
51 | namespace tensorflow { |
52 | namespace { |
53 | |
54 | // https://www.sqlite.org/fileformat.html#record_format |
55 | const uint64 kIdTiers[] = { |
56 | 0x7fffffULL, // 23-bit (3 bytes on disk) |
57 | 0x7fffffffULL, // 31-bit (4 bytes on disk) |
58 | 0x7fffffffffffULL, // 47-bit (5 bytes on disk) |
59 | // remaining bits for future use |
60 | }; |
61 | const int kMaxIdTier = sizeof(kIdTiers) / sizeof(uint64); |
62 | const int kIdCollisionDelayMicros = 10; |
63 | const int kMaxIdCollisions = 21; // sum(2**i*10µs for i in range(21))~=21s |
64 | const int64_t kAbsent = 0LL; |
65 | |
66 | const char* kScalarPluginName = "scalars" ; |
67 | const char* kImagePluginName = "images" ; |
68 | const char* kAudioPluginName = "audio" ; |
69 | const char* kHistogramPluginName = "histograms" ; |
70 | |
71 | const int64_t kReserveMinBytes = 32; |
72 | const double kReserveMultiplier = 1.5; |
73 | const int64_t kPreallocateRows = 1000; |
74 | |
75 | // Flush is a misnomer because what we're actually doing is having lots |
76 | // of commits inside any SqliteTransaction that writes potentially |
77 | // hundreds of megs but doesn't need the transaction to maintain its |
78 | // invariants. This ensures the WAL read penalty is small and might |
79 | // allow writers in other processes a chance to schedule. |
80 | const uint64 kFlushBytes = 1024 * 1024; |
81 | |
82 | double DoubleTime(uint64 micros) { |
83 | // TODO(@jart): Follow precise definitions for time laid out in schema. |
84 | // TODO(@jart): Use monotonic clock from gRPC codebase. |
85 | return static_cast<double>(micros) / 1.0e6; |
86 | } |
87 | |
88 | string StringifyShape(const TensorShape& shape) { |
89 | string result; |
90 | bool first = true; |
91 | for (const auto& dim : shape) { |
92 | if (first) { |
93 | first = false; |
94 | } else { |
95 | strings::StrAppend(&result, "," ); |
96 | } |
97 | strings::StrAppend(&result, dim.size); |
98 | } |
99 | return result; |
100 | } |
101 | |
// Returns OK iff t's dtype is one this writer can serialize
// (exactly the CALL_SUPPORTED_TYPES set); Unimplemented otherwise.
Status CheckSupportedType(const Tensor& t) {
// Each supported dtype becomes an empty (accepting) switch case.
#define CASE(T) \
  case DataTypeToEnum<T>::value: \
    break;
  switch (t.dtype()) {
    CALL_SUPPORTED_TYPES(CASE)
    default:
      return errors::Unimplemented(DataTypeString(t.dtype()),
                                   " tensors unsupported on platform");
  }
  return OkStatus();
#undef CASE
}
115 | |
// Returns a rank-0 tensor holding t's first element (same dtype).
// For unsupported dtypes it degrades to a DT_FLOAT NaN scalar rather
// than failing.
Tensor AsScalar(const Tensor& t) {
  Tensor t2{t.dtype(), {}};
// Copy element 0 of the flattened input into the scalar output.
#define CASE(T) \
  case DataTypeToEnum<T>::value: \
    t2.scalar<T>()() = t.flat<T>()(0); \
    break;
  switch (t.dtype()) {
    CALL_SUPPORTED_TYPES(CASE)
    default:
      // Unsupported dtype: emit a NaN marker instead of an error.
      t2 = {DT_FLOAT, {}};
      t2.scalar<float>()() = NAN;
      break;
  }
  return t2;
#undef CASE
}
132 | |
133 | void PatchPluginName(SummaryMetadata* metadata, const char* name) { |
134 | if (metadata->plugin_data().plugin_name().empty()) { |
135 | metadata->mutable_plugin_data()->set_plugin_name(name); |
136 | } |
137 | } |
138 | |
139 | Status SetDescription(Sqlite* db, int64_t id, const StringPiece& markdown) { |
140 | const char* sql = R"sql( |
141 | INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?) |
142 | )sql" ; |
143 | SqliteStatement insert_desc; |
144 | TF_RETURN_IF_ERROR(db->Prepare(sql, &insert_desc)); |
145 | insert_desc.BindInt(1, id); |
146 | insert_desc.BindText(2, markdown); |
147 | return insert_desc.StepAndReset(); |
148 | } |
149 | |
150 | /// \brief Generates unique IDs randomly in the [1,2**63-1] range. |
151 | /// |
152 | /// This class starts off generating IDs in the [1,2**23-1] range, |
153 | /// because it's human friendly and occupies 4 bytes max on disk with |
154 | /// SQLite's zigzag varint encoding. Then, each time a collision |
155 | /// happens, the random space is increased by 8 bits. |
156 | /// |
157 | /// This class uses exponential back-off so writes gradually slow down |
158 | /// as IDs become exhausted but reads are still possible. |
159 | /// |
160 | /// This class is thread safe. |
161 | class IdAllocator { |
162 | public: |
163 | IdAllocator(Env* env, Sqlite* db) : env_{env}, db_{db} { |
164 | DCHECK(env_ != nullptr); |
165 | DCHECK(db_ != nullptr); |
166 | } |
167 | |
168 | Status CreateNewId(int64_t* id) TF_LOCKS_EXCLUDED(mu_) { |
169 | mutex_lock lock(mu_); |
170 | Status s; |
171 | SqliteStatement stmt; |
172 | TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)" , &stmt)); |
173 | for (int i = 0; i < kMaxIdCollisions; ++i) { |
174 | int64_t tid = MakeRandomId(); |
175 | stmt.BindInt(1, tid); |
176 | s = stmt.StepAndReset(); |
177 | if (s.ok()) { |
178 | *id = tid; |
179 | break; |
180 | } |
181 | // SQLITE_CONSTRAINT maps to INVALID_ARGUMENT in sqlite.cc |
182 | if (s.code() != error::INVALID_ARGUMENT) break; |
183 | if (tier_ < kMaxIdTier) { |
184 | LOG(INFO) << "IdAllocator collision at tier " << tier_ << " (of " |
185 | << kMaxIdTier << ") so auto-adjusting to a higher tier" ; |
186 | ++tier_; |
187 | } else { |
188 | LOG(WARNING) << "IdAllocator (attempt #" << i << ") " |
189 | << "resulted in a collision at the highest tier; this " |
190 | "is problematic if it happens often; you can try " |
191 | "pruning the Ids table; you can also file a bug " |
192 | "asking for the ID space to be increased; otherwise " |
193 | "writes will gradually slow down over time until they " |
194 | "become impossible" ; |
195 | } |
196 | env_->SleepForMicroseconds((1 << i) * kIdCollisionDelayMicros); |
197 | } |
198 | return s; |
199 | } |
200 | |
201 | private: |
202 | int64_t MakeRandomId() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { |
203 | int64_t id = static_cast<int64_t>(random::New64() & kIdTiers[tier_]); |
204 | if (id == kAbsent) ++id; |
205 | return id; |
206 | } |
207 | |
208 | mutex mu_; |
209 | Env* const env_; |
210 | Sqlite* const db_; |
211 | int tier_ TF_GUARDED_BY(mu_) = 0; |
212 | |
213 | TF_DISALLOW_COPY_AND_ASSIGN(IdAllocator); |
214 | }; |
215 | |
// Persists a GraphDef in normalized form: per-node rows in Nodes, edges
// in NodeInputs, and the stripped residual proto in Graphs.
class GraphWriter {
 public:
  /// \brief Writes *graph to the database under a freshly allocated ID.
  ///
  /// The three passes must run in this order: SaveNodeInputs() resolves
  /// input strings against node names, which SaveNodes() clears
  /// afterwards. graph is mutated (fields are cleared as they are
  /// stored) and should be considered consumed by the caller. txn is
  /// committed periodically via MaybeFlush() purely to bound WAL
  /// growth, so the overall write is not atomic.
  static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids,
                     GraphDef* graph, uint64 now, int64_t run_id,
                     int64_t* graph_id)
      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
    TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id));
    GraphWriter saver{db, txn, graph, now, *graph_id};
    saver.MapNameToNodeId();
    TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodeInputs(), "SaveNodeInputs");
    TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveNodes(), "SaveNodes");
    TF_RETURN_WITH_CONTEXT_IF_ERROR(saver.SaveGraph(run_id), "SaveGraph");
    return OkStatus();
  }

 private:
  GraphWriter(Sqlite* db, SqliteTransaction* txn, GraphDef* graph, uint64 now,
              int64_t graph_id)
      : db_(db), txn_(txn), graph_(graph), now_(now), graph_id_(graph_id) {}

  // Builds the name -> node_id index consulted by SaveNodeInputs().
  void MapNameToNodeId() {
    size_t toto = static_cast<size_t>(graph_->node_size());
    // Reserving up front keeps name_copies_ from reallocating, which
    // would invalidate the StringPiece keys that point into it.
    name_copies_.reserve(toto);
    name_to_node_id_.reserve(toto);
    for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
      // Copy name into memory region, since we call clear_name() later.
      // Then wrap in StringPiece so we can compare slices without copy.
      name_copies_.emplace_back(graph_->node(node_id).name());
      name_to_node_id_.emplace(name_copies_.back(), node_id);
    }
  }

  // Writes one NodeInputs row per input edge. Accepts the NodeDef input
  // spellings "name", "name:output_idx", and "^name" (control edge).
  Status SaveNodeInputs() {
    const char* sql = R"sql(
      INSERT INTO NodeInputs (
        graph_id,
        node_id,
        idx,
        input_node_id,
        input_node_idx,
        is_control
      ) VALUES (?, ?, ?, ?, ?, ?)
    )sql";
    SqliteStatement insert;
    TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
    for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
      const NodeDef& node = graph_->node(node_id);
      for (int idx = 0; idx < node.input_size(); ++idx) {
        StringPiece name = node.input(idx);
        int64_t input_node_id;
        int64_t input_node_idx = 0;
        int64_t is_control = 0;
        size_t i = name.rfind(':');
        if (i != StringPiece::npos) {
          // Split off and parse the ":output_index" suffix.
          if (!strings::safe_strto64(name.substr(i + 1, name.size() - i - 1),
                                     &input_node_idx)) {
            return errors::DataLoss("Bad NodeDef.input: ", name);
          }
          name.remove_suffix(name.size() - i);
        }
        if (!name.empty() && name[0] == '^') {
          // Leading '^' marks a control-dependency edge.
          name.remove_prefix(1);
          is_control = 1;
        }
        auto e = name_to_node_id_.find(name);
        if (e == name_to_node_id_.end()) {
          return errors::DataLoss("Could not find node: ", name);
        }
        input_node_id = e->second;
        insert.BindInt(1, graph_id_);
        insert.BindInt(2, node_id);
        insert.BindInt(3, idx);
        insert.BindInt(4, input_node_id);
        insert.BindInt(5, input_node_idx);
        insert.BindInt(6, is_control);
        unflushed_bytes_ += insert.size();
        TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node.name(),
                                        " -> ", name);
        TF_RETURN_IF_ERROR(MaybeFlush());
      }
    }
    return OkStatus();
  }

  // Writes one Nodes row per node. name/op/device/input are stored as
  // columns (or already in NodeInputs) and cleared from the proto, so
  // the node_def blob holds only the residual fields.
  Status SaveNodes() {
    const char* sql = R"sql(
      INSERT INTO Nodes (
        graph_id,
        node_id,
        node_name,
        op,
        device,
        node_def)
      VALUES (?, ?, ?, ?, ?, ?)
    )sql";
    SqliteStatement insert;
    TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
    for (int node_id = 0; node_id < graph_->node_size(); ++node_id) {
      NodeDef* node = graph_->mutable_node(node_id);
      insert.BindInt(1, graph_id_);
      insert.BindInt(2, node_id);
      insert.BindText(3, node->name());
      insert.BindText(4, node->op());
      insert.BindText(5, node->device());
      node->clear_name();
      node->clear_op();
      node->clear_device();
      node->clear_input();
      string node_def;
      if (node->SerializeToString(&node_def)) {
        insert.BindBlobUnsafe(6, node_def);
      }
      unflushed_bytes_ += insert.size();
      // NOTE(review): node->name() was cleared just above, so this error
      // context is always empty — consider capturing the name earlier.
      TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), node->name());
      TF_RETURN_IF_ERROR(MaybeFlush());
    }
    return OkStatus();
  }

  // Writes the Graphs row. Nodes were persisted above, so the node list
  // is cleared before serializing the residual graph_def blob. A NULL
  // run_id (kAbsent) leaves the graph unattached to any run.
  Status SaveGraph(int64_t run_id) {
    const char* sql = R"sql(
      INSERT OR REPLACE INTO Graphs (
        run_id,
        graph_id,
        inserted_time,
        graph_def
      ) VALUES (?, ?, ?, ?)
    )sql";
    SqliteStatement insert;
    TF_RETURN_IF_ERROR(db_->Prepare(sql, &insert));
    if (run_id != kAbsent) insert.BindInt(1, run_id);
    insert.BindInt(2, graph_id_);
    insert.BindDouble(3, DoubleTime(now_));
    graph_->clear_node();
    string graph_def;
    if (graph_->SerializeToString(&graph_def)) {
      insert.BindBlobUnsafe(4, graph_def);
    }
    return insert.StepAndReset();
  }

  // Commits the surrounding transaction once kFlushBytes of bound data
  // accumulate, bounding WAL growth (see comment on kFlushBytes).
  Status MaybeFlush() {
    if (unflushed_bytes_ >= kFlushBytes) {
      TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ",
                                      unflushed_bytes_, " bytes");
      unflushed_bytes_ = 0;
    }
    return OkStatus();
  }

  Sqlite* const db_;
  SqliteTransaction* const txn_;
  uint64 unflushed_bytes_ = 0;  // bytes bound since the last commit
  GraphDef* const graph_;       // not owned; mutated while saving
  const uint64 now_;            // insertion timestamp in microseconds
  const int64_t graph_id_;
  std::vector<string> name_copies_;  // backing storage for the map keys
  std::unordered_map<StringPiece, int64_t, StringPieceHasher> name_to_node_id_;

  TF_DISALLOW_COPY_AND_ASSIGN(GraphWriter);
};
377 | |
/// \brief Run metadata manager.
///
/// This class gives us Tag IDs we can pass to SeriesWriter. In order
/// to do that, rows are created in the Ids, Tags, Runs, Experiments,
/// and Users tables.
///
/// This class is thread safe.
class RunMetadata {
 public:
  RunMetadata(IdAllocator* ids, const string& experiment_name,
              const string& run_name, const string& user_name)
      : ids_{ids},
        experiment_name_{experiment_name},
        run_name_{run_name},
        user_name_{user_name} {
    DCHECK(ids_ != nullptr);
  }

  const string& experiment_name() { return experiment_name_; }
  const string& run_name() { return run_name_; }
  const string& user_name() { return user_name_; }

  // Returns the Runs row ID, or kAbsent if no run row exists yet.
  int64_t run_id() TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    return run_id_;
  }

  // Ensures the run row exists, then saves g via GraphWriter under this
  // run. g is consumed: GraphWriter mutates it while saving.
  Status SetGraph(Sqlite* db, uint64 now, double computed_time,
                  std::unique_ptr<GraphDef> g) SQLITE_TRANSACTIONS_EXCLUDED(*db)
      TF_LOCKS_EXCLUDED(mu_) {
    int64_t run_id;
    {
      mutex_lock lock(mu_);
      TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
      run_id = run_id_;
    }
    int64_t graph_id;
    SqliteTransaction txn(*db);  // only to increase performance
    TF_RETURN_IF_ERROR(
        GraphWriter::Save(db, &txn, ids_, g.get(), now, run_id, &graph_id));
    return txn.Commit();
  }

  // Returns (allocating on first use) the row ID for tag_name, writing
  // the Descriptions and Tags rows as needed. Results are cached per
  // tag name for the lifetime of this object.
  Status GetTagId(Sqlite* db, uint64 now, double computed_time,
                  const string& tag_name, int64_t* tag_id,
                  const SummaryMetadata& metadata) TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time));
    auto e = tag_ids_.find(tag_name);
    if (e != tag_ids_.end()) {
      *tag_id = e->second;
      return OkStatus();
    }
    TF_RETURN_IF_ERROR(ids_->CreateNewId(tag_id));
    // NOTE(review): the cache is populated before the INSERT below runs,
    // so a failed insert leaves a cached ID with no Tags row — confirm
    // this best-effort behavior is intended.
    tag_ids_[tag_name] = *tag_id;
    TF_RETURN_IF_ERROR(
        SetDescription(db, *tag_id, metadata.summary_description()));
    const char* sql = R"sql(
      INSERT INTO Tags (
        run_id,
        tag_id,
        tag_name,
        inserted_time,
        display_name,
        plugin_name,
        plugin_data
      ) VALUES (
        :run_id,
        :tag_id,
        :tag_name,
        :inserted_time,
        :display_name,
        :plugin_name,
        :plugin_data
      )
    )sql";
    SqliteStatement insert;
    TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
    // Leaving :run_id unbound stores NULL when there is no run row.
    if (run_id_ != kAbsent) insert.BindInt(":run_id", run_id_);
    insert.BindInt(":tag_id", *tag_id);
    // Unsafe binds borrow the caller's buffers, which outlive the
    // StepAndReset below.
    insert.BindTextUnsafe(":tag_name", tag_name);
    insert.BindDouble(":inserted_time", DoubleTime(now));
    insert.BindTextUnsafe(":display_name", metadata.display_name());
    insert.BindTextUnsafe(":plugin_name", metadata.plugin_data().plugin_name());
    insert.BindBlobUnsafe(":plugin_data", metadata.plugin_data().content());
    return insert.StepAndReset();
  }

 private:
  // Looks up or creates the Users row for user_name_. No-op when the
  // name is empty or the ID was already resolved.
  Status InitializeUser(Sqlite* db, uint64 now)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (user_id_ != kAbsent || user_name_.empty()) return OkStatus();
    const char* get_sql = R"sql(
      SELECT user_id FROM Users WHERE user_name = ?
    )sql";
    SqliteStatement get;
    TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
    get.BindText(1, user_name_);
    bool is_done;
    TF_RETURN_IF_ERROR(get.Step(&is_done));
    if (!is_done) {
      // Row already exists; reuse its ID.
      user_id_ = get.ColumnInt(0);
      return OkStatus();
    }
    TF_RETURN_IF_ERROR(ids_->CreateNewId(&user_id_));
    const char* insert_sql = R"sql(
      INSERT INTO Users (
        user_id,
        user_name,
        inserted_time
      ) VALUES (?, ?, ?)
    )sql";
    SqliteStatement insert;
    TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
    insert.BindInt(1, user_id_);
    insert.BindText(2, user_name_);
    insert.BindDouble(3, DoubleTime(now));
    TF_RETURN_IF_ERROR(insert.StepAndReset());
    return OkStatus();
  }

  // Looks up or creates the Experiments row (resolving the user first),
  // then backdates its started_time if computed_time is earlier than
  // what was recorded. No-op when experiment_name_ is empty.
  Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (experiment_name_.empty()) return OkStatus();
    if (experiment_id_ == kAbsent) {
      TF_RETURN_IF_ERROR(InitializeUser(db, now));
      const char* get_sql = R"sql(
        SELECT
          experiment_id,
          started_time
        FROM
          Experiments
        WHERE
          user_id IS ?
          AND experiment_name = ?
      )sql";
      SqliteStatement get;
      TF_RETURN_IF_ERROR(db->Prepare(get_sql, &get));
      // "IS ?" matches NULL when no user was bound.
      if (user_id_ != kAbsent) get.BindInt(1, user_id_);
      get.BindText(2, experiment_name_);
      bool is_done;
      TF_RETURN_IF_ERROR(get.Step(&is_done));
      if (!is_done) {
        experiment_id_ = get.ColumnInt(0);
        // NOTE(review): started_time is written as a double below but
        // read back with ColumnInt here — confirm truncation is intended.
        experiment_started_time_ = get.ColumnInt(1);
      } else {
        TF_RETURN_IF_ERROR(ids_->CreateNewId(&experiment_id_));
        experiment_started_time_ = computed_time;
        const char* insert_sql = R"sql(
          INSERT INTO Experiments (
            user_id,
            experiment_id,
            experiment_name,
            inserted_time,
            started_time,
            is_watching
          ) VALUES (?, ?, ?, ?, ?, ?)
        )sql";
        SqliteStatement insert;
        TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
        if (user_id_ != kAbsent) insert.BindInt(1, user_id_);
        insert.BindInt(2, experiment_id_);
        insert.BindText(3, experiment_name_);
        insert.BindDouble(4, DoubleTime(now));
        insert.BindDouble(5, computed_time);
        insert.BindInt(6, 0);
        TF_RETURN_IF_ERROR(insert.StepAndReset());
      }
    }
    if (computed_time < experiment_started_time_) {
      // An earlier data point arrived: move started_time back.
      experiment_started_time_ = computed_time;
      const char* update_sql = R"sql(
        UPDATE
          Experiments
        SET
          started_time = ?
        WHERE
          experiment_id = ?
      )sql";
      SqliteStatement update;
      TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
      update.BindDouble(1, computed_time);
      update.BindInt(2, experiment_id_);
      TF_RETURN_IF_ERROR(update.StepAndReset());
    }
    return OkStatus();
  }

  // Creates the Runs row on first use (resolving the experiment first),
  // then backdates its started_time if computed_time is earlier. No-op
  // when run_name_ is empty.
  Status InitializeRun(Sqlite* db, uint64 now, double computed_time)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (run_name_.empty()) return OkStatus();
    TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time));
    if (run_id_ == kAbsent) {
      TF_RETURN_IF_ERROR(ids_->CreateNewId(&run_id_));
      run_started_time_ = computed_time;
      const char* insert_sql = R"sql(
        INSERT OR REPLACE INTO Runs (
          experiment_id,
          run_id,
          run_name,
          inserted_time,
          started_time
        ) VALUES (?, ?, ?, ?, ?)
      )sql";
      SqliteStatement insert;
      TF_RETURN_IF_ERROR(db->Prepare(insert_sql, &insert));
      if (experiment_id_ != kAbsent) insert.BindInt(1, experiment_id_);
      insert.BindInt(2, run_id_);
      insert.BindText(3, run_name_);
      insert.BindDouble(4, DoubleTime(now));
      insert.BindDouble(5, computed_time);
      TF_RETURN_IF_ERROR(insert.StepAndReset());
    }
    if (computed_time < run_started_time_) {
      // An earlier data point arrived: move started_time back.
      run_started_time_ = computed_time;
      const char* update_sql = R"sql(
        UPDATE
          Runs
        SET
          started_time = ?
        WHERE
          run_id = ?
      )sql";
      SqliteStatement update;
      TF_RETURN_IF_ERROR(db->Prepare(update_sql, &update));
      update.BindDouble(1, computed_time);
      update.BindInt(2, run_id_);
      TF_RETURN_IF_ERROR(update.StepAndReset());
    }
    return OkStatus();
  }

  mutex mu_;
  IdAllocator* const ids_;  // not owned
  const string experiment_name_;
  const string run_name_;
  const string user_name_;
  // Row IDs resolved lazily by the Initialize* methods; kAbsent until then.
  int64_t experiment_id_ TF_GUARDED_BY(mu_) = kAbsent;
  int64_t run_id_ TF_GUARDED_BY(mu_) = kAbsent;
  int64_t user_id_ TF_GUARDED_BY(mu_) = kAbsent;
  double experiment_started_time_ TF_GUARDED_BY(mu_) = 0.0;
  double run_started_time_ TF_GUARDED_BY(mu_) = 0.0;
  std::unordered_map<string, int64_t> tag_ids_ TF_GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(RunMetadata);
};
624 | |
/// \brief Tensor writer for a single series, e.g. Tag.
///
/// This class is thread safe.
class SeriesWriter {
 public:
  SeriesWriter(int64_t series, RunMetadata* meta)
      : series_{series}, meta_{meta} {
    DCHECK(series_ > 0);
  }

  // Stores t at (step, computed_time) by filling in one of the
  // preallocated Tensors rows (see Reserve), consumed front-to-back.
  Status Append(Sqlite* db, int64_t step, uint64 now, double computed_time,
                const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    if (rowids_.empty()) {
      // Out of preallocated rows: reserve another batch sized off t.
      Status s = Reserve(db, t);
      if (!s.ok()) {
        // Drop any partially reserved rowids so the next Append retries.
        rowids_.clear();
        return s;
      }
    }
    int64_t rowid = rowids_.front();
    Status s = Write(db, rowid, step, computed_time, t);
    if (s.ok()) {
      ++count_;
    }
    // The rowid is consumed even on failure so it is never reused.
    rowids_.pop_front();
    return s;
  }

  // Removes preallocated Tensors rows that never received data.
  Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db)
      TF_LOCKS_EXCLUDED(mu_) {
    mutex_lock lock(mu_);
    // Delete unused pre-allocated Tensors.
    if (!rowids_.empty()) {
      SqliteTransaction txn(*db);
      const char* sql = R"sql(
        DELETE FROM Tensors WHERE rowid = ?
      )sql";
      SqliteStatement deleter;
      TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter));
      // NOTE(review): i starts at count_ while rowids_.size() shrinks on
      // every pop_front(), so only part of the remaining rows are deleted
      // before the clear() below abandons the rest — confirm the intended
      // loop bounds.
      for (size_t i = count_; i < rowids_.size(); ++i) {
        deleter.BindInt(1, rowids_.front());
        TF_RETURN_IF_ERROR(deleter.StepAndReset());
        rowids_.pop_front();
      }
      TF_RETURN_IF_ERROR(txn.Commit());
      rowids_.clear();
    }
    return OkStatus();
  }

 private:
  // Dispatches on dtype: non-string tensors store the raw tensor_data()
  // blob; rank-0 strings store the string inline; higher-rank string
  // tensors store one TensorStrings row per element.
  Status Write(Sqlite* db, int64_t rowid, int64_t step, double computed_time,
               const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) {
    if (t.dtype() == DT_STRING) {
      if (t.dims() == 0) {
        return Update(db, step, computed_time, t, t.scalar<tstring>()(), rowid);
      } else {
        // Keep the row update and the per-element rows atomic together.
        SqliteTransaction txn(*db);
        TF_RETURN_IF_ERROR(
            Update(db, step, computed_time, t, StringPiece(), rowid));
        TF_RETURN_IF_ERROR(UpdateNdString(db, t, rowid));
        return txn.Commit();
      }
    } else {
      return Update(db, step, computed_time, t, t.tensor_data(), rowid);
    }
  }

  // Overwrites the preallocated row identified by rowid with the
  // tensor's step, time, dtype, shape, and data blob.
  Status Update(Sqlite* db, int64_t step, double computed_time, const Tensor& t,
                const StringPiece& data, int64_t rowid) {
    const char* sql = R"sql(
      UPDATE OR REPLACE
        Tensors
      SET
        step = ?,
        computed_time = ?,
        dtype = ?,
        shape = ?,
        data = ?
      WHERE
        rowid = ?
    )sql";
    SqliteStatement stmt;
    TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt));
    stmt.BindInt(1, step);
    stmt.BindDouble(2, computed_time);
    stmt.BindInt(3, t.dtype());
    stmt.BindText(4, StringifyShape(t.shape()));
    // Unsafe bind borrows data, which outlives the StepAndReset below.
    stmt.BindBlobUnsafe(5, data);
    stmt.BindInt(6, rowid);
    TF_RETURN_IF_ERROR(stmt.StepAndReset());
    return OkStatus();
  }

  // Replaces the TensorStrings rows for tensor_rowid with one row per
  // element of the string tensor t (flattened index order).
  Status UpdateNdString(Sqlite* db, const Tensor& t, int64_t tensor_rowid)
      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) {
    DCHECK_EQ(t.dtype(), DT_STRING);
    DCHECK_GT(t.dims(), 0);
    const char* deleter_sql = R"sql(
      DELETE FROM TensorStrings WHERE tensor_rowid = ?
    )sql";
    SqliteStatement deleter;
    TF_RETURN_IF_ERROR(db->Prepare(deleter_sql, &deleter));
    deleter.BindInt(1, tensor_rowid);
    TF_RETURN_WITH_CONTEXT_IF_ERROR(deleter.StepAndReset(), tensor_rowid);
    const char* inserter_sql = R"sql(
      INSERT INTO TensorStrings (
        tensor_rowid,
        idx,
        data
      ) VALUES (?, ?, ?)
    )sql";
    SqliteStatement inserter;
    TF_RETURN_IF_ERROR(db->Prepare(inserter_sql, &inserter));
    auto flat = t.flat<tstring>();
    for (int64_t i = 0; i < flat.size(); ++i) {
      inserter.BindInt(1, tensor_rowid);
      inserter.BindInt(2, i);
      inserter.BindBlobUnsafe(3, flat(i));
      TF_RETURN_WITH_CONTEXT_IF_ERROR(inserter.StepAndReset(), "i=", i);
    }
    return OkStatus();
  }

  // Preallocates a batch of zero-filled rows sized from t so subsequent
  // appends can update in place instead of inserting.
  Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    SqliteTransaction txn(*db);  // only for performance
    unflushed_bytes_ = 0;
    if (t.dtype() == DT_STRING) {
      if (t.dims() == 0) {
        // Scalar strings live inline in the data column.
        TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.scalar<tstring>()().size()));
      } else {
        // Nd strings go into TensorStrings, so rows need minimal space.
        TF_RETURN_IF_ERROR(ReserveTensors(db, &txn, kReserveMinBytes));
      }
    } else {
      TF_RETURN_IF_ERROR(ReserveData(db, &txn, t.tensor_data().size()));
    }
    return txn.Commit();
  }

  // Converts a payload size into a padded reservation size: a
  // kReserveMultiplier headroom, floored at kReserveMinBytes.
  Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size)
      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    int64_t space =
        static_cast<int64_t>(static_cast<double>(size) * kReserveMultiplier);
    if (space < kReserveMinBytes) space = kReserveMinBytes;
    return ReserveTensors(db, txn, space);
  }

  // Inserts kPreallocateRows zero-blob rows of reserved_bytes each,
  // recording their rowids for later in-place updates.
  Status ReserveTensors(Sqlite* db, SqliteTransaction* txn,
                        int64_t reserved_bytes)
      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    const char* sql = R"sql(
      INSERT INTO Tensors (
        series,
        data
      ) VALUES (?, ZEROBLOB(?))
    )sql";
    SqliteStatement insert;
    TF_RETURN_IF_ERROR(db->Prepare(sql, &insert));
    // TODO(jart): Maybe preallocate index pages by setting step. This
    //             is tricky because UPDATE OR REPLACE can have a side
    //             effect of deleting preallocated rows.
    for (int64_t i = 0; i < kPreallocateRows; ++i) {
      insert.BindInt(1, series_);
      insert.BindInt(2, reserved_bytes);
      TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i);
      rowids_.push_back(db->last_insert_rowid());
      unflushed_bytes_ += reserved_bytes;
      TF_RETURN_IF_ERROR(MaybeFlush(db, txn));
    }
    return OkStatus();
  }

  // Commits the batch transaction once kFlushBytes accumulate, keeping
  // the WAL small (see comment on kFlushBytes).
  Status MaybeFlush(Sqlite* db, SqliteTransaction* txn)
      SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (unflushed_bytes_ >= kFlushBytes) {
      TF_RETURN_WITH_CONTEXT_IF_ERROR(txn->Commit(), "flushing ",
                                      unflushed_bytes_, " bytes");
      unflushed_bytes_ = 0;
    }
    return OkStatus();
  }

  mutex mu_;
  const int64_t series_;     // Tag ID whose Tensors rows this writer fills
  RunMetadata* const meta_;  // not owned
  uint64 count_ TF_GUARDED_BY(mu_) = 0;  // successful writes so far
  std::deque<int64_t> rowids_ TF_GUARDED_BY(mu_);  // unused preallocated rows
  uint64 unflushed_bytes_ TF_GUARDED_BY(mu_) = 0;  // since last batch commit

  TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter);
};
822 | |
823 | /// \brief Tensor writer for a single Run. |
824 | /// |
825 | /// This class farms out tensors to SeriesWriter instances. It also |
826 | /// keeps track of whether or not someone is watching the TensorBoard |
827 | /// GUI, so it can avoid writes when possible. |
828 | /// |
829 | /// This class is thread safe. |
830 | class RunWriter { |
831 | public: |
832 | explicit RunWriter(RunMetadata* meta) : meta_{meta} {} |
833 | |
834 | Status Append(Sqlite* db, int64_t tag_id, int64_t step, uint64 now, |
835 | double computed_time, const Tensor& t) |
836 | SQLITE_TRANSACTIONS_EXCLUDED(*db) TF_LOCKS_EXCLUDED(mu_) { |
837 | SeriesWriter* writer = GetSeriesWriter(tag_id); |
838 | return writer->Append(db, step, now, computed_time, t); |
839 | } |
840 | |
841 | Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) |
842 | TF_LOCKS_EXCLUDED(mu_) { |
843 | mutex_lock lock(mu_); |
844 | if (series_writers_.empty()) return OkStatus(); |
845 | for (auto i = series_writers_.begin(); i != series_writers_.end(); ++i) { |
846 | if (!i->second) continue; |
847 | TF_RETURN_WITH_CONTEXT_IF_ERROR(i->second->Finish(db), |
848 | "finish tag_id=" , i->first); |
849 | i->second.reset(); |
850 | } |
851 | return OkStatus(); |
852 | } |
853 | |
854 | private: |
855 | SeriesWriter* GetSeriesWriter(int64_t tag_id) TF_LOCKS_EXCLUDED(mu_) { |
856 | mutex_lock sl(mu_); |
857 | auto spot = series_writers_.find(tag_id); |
858 | if (spot == series_writers_.end()) { |
859 | SeriesWriter* writer = new SeriesWriter(tag_id, meta_); |
860 | series_writers_[tag_id].reset(writer); |
861 | return writer; |
862 | } else { |
863 | return spot->second.get(); |
864 | } |
865 | } |
866 | |
867 | mutex mu_; |
868 | RunMetadata* const meta_; |
869 | std::unordered_map<int64_t, std::unique_ptr<SeriesWriter>> series_writers_ |
870 | TF_GUARDED_BY(mu_); |
871 | |
872 | TF_DISALLOW_COPY_AND_ASSIGN(RunWriter); |
873 | }; |
874 | |
875 | /// \brief SQLite implementation of SummaryWriterInterface. |
876 | /// |
877 | /// This class is thread safe. |
878 | class SummaryDbWriter : public SummaryWriterInterface { |
879 | public: |
  // Binds the writer to db and the (experiment, run, user) naming triple.
  // Takes a reference on db; it is released in the destructor.
  SummaryDbWriter(Env* env, Sqlite* db, const string& experiment_name,
                  const string& run_name, const string& user_name)
      : SummaryWriterInterface(),
        env_{env},
        db_{db},
        ids_{env_, db_},
        meta_{&ids_, experiment_name, run_name, user_name},
        run_{&meta_} {
    DCHECK(env_ != nullptr);
    db_->Ref();
  }
891 | |
  // Best-effort shutdown: finalizes all series writers, then stamps the
  // run's finished_time. Errors are logged, never thrown.
  ~SummaryDbWriter() override {
    core::ScopedUnref unref(db_);  // balances db_->Ref() in the constructor
    Status s = run_.Finish(db_);
    if (!s.ok()) {
      // TODO(jart): Retry on transient errors here.
      LOG(ERROR) << s.ToString();
    }
    int64_t run_id = meta_.run_id();
    if (run_id == kAbsent) return;  // no Runs row was ever created
    const char* sql = R"sql(
      UPDATE Runs SET finished_time = ? WHERE run_id = ?
    )sql";
    SqliteStatement update;
    s = db_->Prepare(sql, &update);
    if (s.ok()) {
      update.BindDouble(1, DoubleTime(env_->NowMicros()));
      update.BindInt(2, run_id);
      s = update.StepAndReset();
    }
    if (!s.ok()) {
      LOG(ERROR) << "Failed to set Runs[" << run_id
                 << "].finish_time: " << s.ToString();
    }
  }
916 | |
917 | Status Flush() override { return OkStatus(); } |
918 | |
919 | Status WriteTensor(int64_t global_step, Tensor t, const string& tag, |
920 | const string& serialized_metadata) override { |
921 | TF_RETURN_IF_ERROR(CheckSupportedType(t)); |
922 | SummaryMetadata metadata; |
923 | if (!metadata.ParseFromString(serialized_metadata)) { |
924 | return errors::InvalidArgument("Bad serialized_metadata" ); |
925 | } |
926 | return Write(global_step, t, tag, metadata); |
927 | } |
928 | |
929 | Status WriteScalar(int64_t global_step, Tensor t, |
930 | const string& tag) override { |
931 | TF_RETURN_IF_ERROR(CheckSupportedType(t)); |
932 | SummaryMetadata metadata; |
933 | PatchPluginName(&metadata, kScalarPluginName); |
934 | return Write(global_step, AsScalar(t), tag, metadata); |
935 | } |
936 | |
937 | Status WriteGraph(int64_t global_step, std::unique_ptr<GraphDef> g) override { |
938 | uint64 now = env_->NowMicros(); |
939 | return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g)); |
940 | } |
941 | |
942 | Status WriteEvent(std::unique_ptr<Event> e) override { |
943 | return MigrateEvent(std::move(e)); |
944 | } |
945 | |
946 | Status WriteHistogram(int64_t global_step, Tensor t, |
947 | const string& tag) override { |
948 | uint64 now = env_->NowMicros(); |
949 | std::unique_ptr<Event> e{new Event}; |
950 | e->set_step(global_step); |
951 | e->set_wall_time(DoubleTime(now)); |
952 | TF_RETURN_IF_ERROR( |
953 | AddTensorAsHistogramToSummary(t, tag, e->mutable_summary())); |
954 | return MigrateEvent(std::move(e)); |
955 | } |
956 | |
957 | Status WriteImage(int64_t global_step, Tensor t, const string& tag, |
958 | int max_images, Tensor bad_color) override { |
959 | uint64 now = env_->NowMicros(); |
960 | std::unique_ptr<Event> e{new Event}; |
961 | e->set_step(global_step); |
962 | e->set_wall_time(DoubleTime(now)); |
963 | TF_RETURN_IF_ERROR(AddTensorAsImageToSummary(t, tag, max_images, bad_color, |
964 | e->mutable_summary())); |
965 | return MigrateEvent(std::move(e)); |
966 | } |
967 | |
968 | Status WriteAudio(int64_t global_step, Tensor t, const string& tag, |
969 | int max_outputs, float sample_rate) override { |
970 | uint64 now = env_->NowMicros(); |
971 | std::unique_ptr<Event> e{new Event}; |
972 | e->set_step(global_step); |
973 | e->set_wall_time(DoubleTime(now)); |
974 | TF_RETURN_IF_ERROR(AddTensorAsAudioToSummary( |
975 | t, tag, max_outputs, sample_rate, e->mutable_summary())); |
976 | return MigrateEvent(std::move(e)); |
977 | } |
978 | |
979 | string DebugString() const override { return "SummaryDbWriter" ; } |
980 | |
981 | private: |
982 | Status Write(int64_t step, const Tensor& t, const string& tag, |
983 | const SummaryMetadata& metadata) { |
984 | uint64 now = env_->NowMicros(); |
985 | double computed_time = DoubleTime(now); |
986 | int64_t tag_id; |
987 | TF_RETURN_IF_ERROR( |
988 | meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata)); |
989 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
990 | run_.Append(db_, tag_id, step, now, computed_time, t), |
991 | meta_.user_name(), "/" , meta_.experiment_name(), "/" , meta_.run_name(), |
992 | "/" , tag, "@" , step); |
993 | return OkStatus(); |
994 | } |
995 | |
996 | Status MigrateEvent(std::unique_ptr<Event> e) { |
997 | switch (e->what_case()) { |
998 | case Event::WhatCase::kSummary: { |
999 | uint64 now = env_->NowMicros(); |
1000 | auto summaries = e->mutable_summary(); |
1001 | for (int i = 0; i < summaries->value_size(); ++i) { |
1002 | Summary::Value* value = summaries->mutable_value(i); |
1003 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
1004 | MigrateSummary(e.get(), value, now), meta_.user_name(), "/" , |
1005 | meta_.experiment_name(), "/" , meta_.run_name(), "/" , value->tag(), |
1006 | "@" , e->step()); |
1007 | } |
1008 | break; |
1009 | } |
1010 | case Event::WhatCase::kGraphDef: |
1011 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
1012 | MigrateGraph(e.get(), e->graph_def()), meta_.user_name(), "/" , |
1013 | meta_.experiment_name(), "/" , meta_.run_name(), "/__graph__@" , |
1014 | e->step()); |
1015 | break; |
1016 | default: |
1017 | // TODO(@jart): Handle other stuff. |
1018 | break; |
1019 | } |
1020 | return OkStatus(); |
1021 | } |
1022 | |
1023 | Status MigrateGraph(const Event* e, const string& graph_def) { |
1024 | uint64 now = env_->NowMicros(); |
1025 | std::unique_ptr<GraphDef> graph{new GraphDef}; |
1026 | if (!ParseProtoUnlimited(graph.get(), graph_def)) { |
1027 | return errors::InvalidArgument("bad proto" ); |
1028 | } |
1029 | return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph)); |
1030 | } |
1031 | |
1032 | Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) { |
1033 | switch (s->value_case()) { |
1034 | case Summary::Value::ValueCase::kTensor: |
1035 | TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor" ); |
1036 | break; |
1037 | case Summary::Value::ValueCase::kSimpleValue: |
1038 | TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateScalar(e, s, now), "scalar" ); |
1039 | break; |
1040 | case Summary::Value::ValueCase::kHisto: |
1041 | TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateHistogram(e, s, now), "histo" ); |
1042 | break; |
1043 | case Summary::Value::ValueCase::kImage: |
1044 | TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateImage(e, s, now), "image" ); |
1045 | break; |
1046 | case Summary::Value::ValueCase::kAudio: |
1047 | TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateAudio(e, s, now), "audio" ); |
1048 | break; |
1049 | default: |
1050 | break; |
1051 | } |
1052 | return OkStatus(); |
1053 | } |
1054 | |
1055 | Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) { |
1056 | Tensor t; |
1057 | if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto" ); |
1058 | TF_RETURN_IF_ERROR(CheckSupportedType(t)); |
1059 | int64_t tag_id; |
1060 | TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), |
1061 | &tag_id, s->metadata())); |
1062 | return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); |
1063 | } |
1064 | |
1065 | // TODO(jart): Refactor Summary -> Tensor logic into separate file. |
1066 | |
1067 | Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) { |
1068 | // See tensorboard/plugins/scalar/summary.py and data_compat.py |
1069 | Tensor t{DT_FLOAT, {}}; |
1070 | t.scalar<float>()() = s->simple_value(); |
1071 | int64_t tag_id; |
1072 | PatchPluginName(s->mutable_metadata(), kScalarPluginName); |
1073 | TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), |
1074 | &tag_id, s->metadata())); |
1075 | return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); |
1076 | } |
1077 | |
1078 | Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { |
1079 | const HistogramProto& histo = s->histo(); |
1080 | int k = histo.bucket_size(); |
1081 | if (k != histo.bucket_limit_size()) { |
1082 | return errors::InvalidArgument("size mismatch" ); |
1083 | } |
1084 | // See tensorboard/plugins/histogram/summary.py and data_compat.py |
1085 | Tensor t{DT_DOUBLE, {k, 3}}; |
1086 | auto data = t.flat<double>(); |
1087 | for (int i = 0, j = 0; i < k; ++i) { |
1088 | // TODO(nickfelt): reconcile with TensorBoard's data_compat.py |
1089 | // From summary.proto |
1090 | // Parallel arrays encoding the bucket boundaries and the bucket values. |
1091 | // bucket(i) is the count for the bucket i. The range for |
1092 | // a bucket is: |
1093 | // i == 0: -DBL_MAX .. bucket_limit(0) |
1094 | // i != 0: bucket_limit(i-1) .. bucket_limit(i) |
1095 | double left_edge = (i == 0) ? std::numeric_limits<double>::min() |
1096 | : histo.bucket_limit(i - 1); |
1097 | |
1098 | data(j++) = left_edge; |
1099 | data(j++) = histo.bucket_limit(i); |
1100 | data(j++) = histo.bucket(i); |
1101 | } |
1102 | int64_t tag_id; |
1103 | PatchPluginName(s->mutable_metadata(), kHistogramPluginName); |
1104 | TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), |
1105 | &tag_id, s->metadata())); |
1106 | return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); |
1107 | } |
1108 | |
1109 | Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) { |
1110 | // See tensorboard/plugins/image/summary.py and data_compat.py |
1111 | Tensor t{DT_STRING, {3}}; |
1112 | auto img = s->mutable_image(); |
1113 | t.flat<tstring>()(0) = strings::StrCat(img->width()); |
1114 | t.flat<tstring>()(1) = strings::StrCat(img->height()); |
1115 | t.flat<tstring>()(2) = std::move(*img->mutable_encoded_image_string()); |
1116 | int64_t tag_id; |
1117 | PatchPluginName(s->mutable_metadata(), kImagePluginName); |
1118 | TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), |
1119 | &tag_id, s->metadata())); |
1120 | return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); |
1121 | } |
1122 | |
1123 | Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) { |
1124 | // See tensorboard/plugins/audio/summary.py and data_compat.py |
1125 | Tensor t{DT_STRING, {1, 2}}; |
1126 | auto wav = s->mutable_audio(); |
1127 | t.flat<tstring>()(0) = std::move(*wav->mutable_encoded_audio_string()); |
1128 | t.flat<tstring>()(1) = "" ; |
1129 | int64_t tag_id; |
1130 | PatchPluginName(s->mutable_metadata(), kAudioPluginName); |
1131 | TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), |
1132 | &tag_id, s->metadata())); |
1133 | return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); |
1134 | } |
1135 | |
1136 | Env* const env_; |
1137 | Sqlite* const db_; |
1138 | IdAllocator ids_; |
1139 | RunMetadata meta_; |
1140 | RunWriter run_; |
1141 | }; |
1142 | |
1143 | } // namespace |
1144 | |
1145 | Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name, |
1146 | const string& run_name, const string& user_name, |
1147 | Env* env, SummaryWriterInterface** result) { |
1148 | *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name); |
1149 | return OkStatus(); |
1150 | } |
1151 | |
1152 | } // namespace tensorflow |
1153 | |