/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/snapshot_utils.h"

#include <algorithm>
#include <functional>
#include <queue>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/snapshot.pb.h"

namespace tensorflow {
namespace data {
namespace snapshot_util {
namespace {

constexpr const char* const kOutputTypes = "output_types";
constexpr const char* const kOutputShapes = "output_shapes";
constexpr const char* const kCompression = "compression";
constexpr const char* const kVersion = "version";
constexpr const char* const kCurrentCheckpointID = "current_checkpoint_id";
constexpr const char* const kIndex = "index";
constexpr const char* const kStartIndex = "start_index";

}  // namespace

/* static */ constexpr const int64_t
    CustomReader::kSnappyReaderInputBufferSizeBytes;
/* static */ constexpr const int64_t
    CustomReader::kSnappyReaderOutputBufferSizeBytes;

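// Helpers that build the on-disk layout of a snapshot:
//   <path>/<hash>/<run_id>/<shard_id><kShardDirectorySuffix>/<checkpoint_id>.snapshot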
std::string HashDirectory(const std::string& path, uint64 hash) {
  return io::JoinPath(
      path, strings::Printf("%llu", static_cast<unsigned long long>(hash)));
}

std::string RunDirectory(const std::string& hash_directory, uint64 run_id) {
  return RunDirectory(
      hash_directory,
      strings::Printf("%llu", static_cast<unsigned long long>(run_id)));
}

std::string RunDirectory(const std::string& hash_directory,
                         const std::string& run_id) {
  return io::JoinPath(hash_directory, run_id);
}

std::string ShardDirectory(const std::string& run_directory, int64_t shard_id) {
  return io::JoinPath(
      run_directory,
      strings::Printf("%08llu%s", static_cast<unsigned long long>(shard_id),
                      kShardDirectorySuffix));
}

std::string GetCheckpointFileName(const std::string& shard_directory,
                                  uint64 checkpoint_id) {
  return io::JoinPath(
      shard_directory,
      strings::Printf("%08llu.snapshot",
                      static_cast<unsigned long long>(checkpoint_id)));
}

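// Creates a snapshot writer for the requested file format version and
// initializes it. Version 1 uses the custom record format below; version 2
// writes serialized TensorProtos as TFRecords.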
Status Writer::Create(Env* env, const std::string& filename,
                      const std::string& compression_type, int version,
                      const DataTypeVector& dtypes,
                      std::unique_ptr<Writer>* out_writer) {
  switch (version) {
    case 1:
      *out_writer =
          std::make_unique<CustomWriter>(filename, compression_type, dtypes);
      break;
    case 2:
      *out_writer =
          std::make_unique<TFRecordWriter>(filename, compression_type);
      break;
    default:
      return errors::InvalidArgument("Snapshot writer version: ", version,
                                     " is not supported.");
  }

  return (*out_writer)->Initialize(env);
}

TFRecordWriter::TFRecordWriter(const std::string& filename,
                               const std::string& compression_type)
    : filename_(filename), compression_type_(compression_type) {}

Status TFRecordWriter::Initialize(tensorflow::Env* env) {
  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));

  record_writer_ = std::make_unique<io::RecordWriter>(
      dest_.get(), io::RecordWriterOptions::CreateRecordWriterOptions(
                       /*compression_type=*/compression_type_));
  return OkStatus();
}

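// Serializes each tensor to a TensorProto and writes it as one TFRecord.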
Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    TensorProto proto;
    tensor.AsProtoTensorContent(&proto);
#if defined(TF_CORD_SUPPORT)
    // Use a raw pointer here: in OSS TF, capturing a smart pointer with
    // std::move() would move it when the cleanup lambda is created, leaving
    // proto_buffer == nullptr by the time WriteRecord runs.
    auto proto_buffer = new std::string();
    proto.SerializeToString(proto_buffer);
    absl::Cord proto_serialized = absl::MakeCordFromExternal(
        *proto_buffer,
        [proto_buffer](absl::string_view) { delete proto_buffer; });
    TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
#else  // TF_CORD_SUPPORT
    TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
#endif  // TF_CORD_SUPPORT
  }
  return OkStatus();
}

Status TFRecordWriter::Sync() {
  TF_RETURN_IF_ERROR(record_writer_->Flush());
  return dest_->Flush();
}

Status TFRecordWriter::Close() {
  if (record_writer_ != nullptr) {
    TF_RETURN_IF_ERROR(Sync());
    TF_RETURN_IF_ERROR(record_writer_->Close());
    TF_RETURN_IF_ERROR(dest_->Close());
    record_writer_ = nullptr;
    dest_ = nullptr;
  }
  return OkStatus();
}

TFRecordWriter::~TFRecordWriter() {
  Status s = Close();
  if (!s.ok()) {
    LOG(ERROR) << "Failed to close snapshot file " << filename_ << ": " << s;
  }
}

CustomWriter::CustomWriter(const std::string& filename,
                           const std::string& compression_type,
                           const DataTypeVector& dtypes)
    : filename_(filename),
      compression_type_(compression_type),
      dtypes_(dtypes) {}

Status CustomWriter::Initialize(tensorflow::Env* env) {
  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
#if defined(IS_SLIM_BUILD)
  if (compression_type_ != io::compression::kNone) {
    LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
               << "off compression.";
  }
#else  // IS_SLIM_BUILD
  if (compression_type_ == io::compression::kGzip) {
    zlib_underlying_dest_.swap(dest_);
    io::ZlibCompressionOptions zlib_options;
    zlib_options = io::ZlibCompressionOptions::GZIP();

    io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
        zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
        zlib_options.output_buffer_size, zlib_options);
    TF_CHECK_OK(zlib_output_buffer->Init());
    dest_.reset(zlib_output_buffer);
  }
#endif  // IS_SLIM_BUILD
  simple_tensor_mask_.reserve(dtypes_.size());
  for (const auto& dtype : dtypes_) {
    if (DataTypeCanUseMemcpy(dtype)) {
      simple_tensor_mask_.push_back(true);
      num_simple_++;
    } else {
      simple_tensor_mask_.push_back(false);
      num_complex_++;
    }
  }

  return OkStatus();
}

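// Writes one element. For non-snappy compression the tensors are written as a
// single serialized SnapshotRecord. For snappy compression, a
// SnapshotTensorMetadata record (shapes and byte sizes) is written first,
// followed by one snappy-compressed block containing the raw tensor buffers
// and serialized protos back to back.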
Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
  if (compression_type_ != io::compression::kSnappy) {
    experimental::SnapshotRecord record;
    for (const auto& tensor : tensors) {
      TensorProto* t = record.add_tensor();
      tensor.AsProtoTensorContent(t);
    }
#if defined(TF_CORD_SUPPORT)
    auto record_buffer = new std::string();
    record.SerializeToString(record_buffer);
    absl::Cord record_serialized = absl::MakeCordFromExternal(
        *record_buffer,
        [record_buffer](absl::string_view) { delete record_buffer; });
    return WriteRecord(record_serialized);
#else  // TF_CORD_SUPPORT
    return WriteRecord(record.SerializeAsString());
#endif  // TF_CORD_SUPPORT
  }

  std::vector<const TensorBuffer*> tensor_buffers;
  tensor_buffers.reserve(num_simple_);
  std::vector<TensorProto> tensor_protos;
  tensor_protos.reserve(num_complex_);
  experimental::SnapshotTensorMetadata metadata;
  int64_t total_size = 0;
  for (int i = 0, end = tensors.size(); i < end; ++i) {
    const Tensor& tensor = tensors[i];
    experimental::TensorMetadata* tensor_metadata =
        metadata.add_tensor_metadata();
    tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
    int64_t size = 0;
    if (simple_tensor_mask_[i]) {
      auto tensor_buffer = DMAHelper::buffer(&tensor);
      tensor_buffers.push_back(tensor_buffer);
      size = tensor_buffer->size();
    } else {
      TensorProto proto;
      tensor.AsProtoTensorContent(&proto);
      size = proto.ByteSizeLong();
      tensor_protos.push_back(std::move(proto));
    }
    tensor_metadata->set_tensor_size_bytes(size);
    total_size += size;
  }

  std::vector<char> uncompressed(total_size);
  char* position = uncompressed.data();
  int buffer_index = 0;
  int proto_index = 0;
  for (int i = 0, end = tensors.size(); i < end; ++i) {
    const auto& tensor_metadata = metadata.tensor_metadata(i);
    if (simple_tensor_mask_[i]) {
      memcpy(position, tensor_buffers[buffer_index]->data(),
             tensor_metadata.tensor_size_bytes());
      buffer_index++;
    } else {
      tensor_protos[proto_index].SerializeToArray(
          position, tensor_metadata.tensor_size_bytes());
      proto_index++;
    }
    position += tensor_metadata.tensor_size_bytes();
  }
  DCHECK_EQ(position, uncompressed.data() + total_size);

  string output;
  if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
    return errors::Internal("Failed to compress using snappy.");
  }

#if defined(TF_CORD_SUPPORT)
  auto metadata_buffer = new std::string();
  metadata.SerializeToString(metadata_buffer);
  absl::Cord metadata_serialized = absl::MakeCordFromExternal(
      *metadata_buffer,
      [metadata_buffer](absl::string_view) { delete metadata_buffer; });
#else
  std::string metadata_serialized = metadata.SerializeAsString();
#endif  // TF_CORD_SUPPORT
  TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
  TF_RETURN_IF_ERROR(WriteRecord(output));
  return OkStatus();
}

Status CustomWriter::Sync() { return dest_->Sync(); }

Status CustomWriter::Close() {
  if (dest_ != nullptr) {
    TF_RETURN_IF_ERROR(dest_->Close());
    dest_ = nullptr;
  }
  if (zlib_underlying_dest_ != nullptr) {
    TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
    zlib_underlying_dest_ = nullptr;
  }
  return OkStatus();
}

CustomWriter::~CustomWriter() {
  Status s = Close();
  if (!s.ok()) {
    LOG(ERROR) << "Could not finish writing file: " << s;
  }
}

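// Records are framed as an 8-byte length header (core::EncodeFixed64) followed
// by the payload bytes; CustomReader::ReadRecord decodes the same layout.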
Status CustomWriter::WriteRecord(const StringPiece& data) {
  char header[kHeaderSize];
  core::EncodeFixed64(header, data.size());
  TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
  return dest_->Append(data);
}

#if defined(TF_CORD_SUPPORT)
Status CustomWriter::WriteRecord(const absl::Cord& data) {
  char header[kHeaderSize];
  core::EncodeFixed64(header, data.size());
  TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
  return dest_->Append(data);
}
#endif  // TF_CORD_SUPPORT

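// Creates and initializes a snapshot reader matching the on-disk format
// version: versions 0 and 1 use CustomReader, version 2 uses TFRecordReader.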
Status Reader::Create(Env* env, const std::string& filename,
                      const string& compression_type, int version,
                      const DataTypeVector& dtypes,
                      std::unique_ptr<Reader>* out_reader) {
  switch (version) {
    // CustomReader can read the legacy snapshot file format (v0), though
    // CustomWriter no longer writes it since it is strictly worse than v1.
    case 0:
    case 1:
      *out_reader = std::make_unique<CustomReader>(filename, compression_type,
                                                   version, dtypes);
      break;
    case 2:
      *out_reader =
          std::make_unique<TFRecordReader>(filename, compression_type, dtypes);
      break;
    default:
      return errors::InvalidArgument("Snapshot reader version: ", version,
                                     " is not supported.");
  }

  return (*out_reader)->Initialize(env);
}

Status Reader::SkipRecords(int64_t num_records) {
  // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip.
  for (int i = 0; i < num_records; ++i) {
    std::vector<Tensor> unused_tensors;
    TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors));
  }
  return OkStatus();
}

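// A dataset that reads the snapshot checkpoint files of a single shard
// directory in order, starting `start_index` elements into the shard.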
class Reader::Dataset : public DatasetBase {
 public:
  Dataset(DatasetContext&& ctx, const std::string& shard_dir,
          const std::string& compression, const int64_t version,
          const DataTypeVector& dtypes,
          const std::vector<PartialTensorShape>& shapes,
          const int64_t start_index)
      : DatasetBase(std::move(ctx)),
        shard_dir_(shard_dir),
        compression_(compression),
        version_(version),
        dtypes_(dtypes),
        shapes_(shapes),
        start_index_(start_index) {}

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return shapes_;
  }

  std::string DebugString() const override { return "SnapshotDatasetReader"; }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    return OkStatus();
  }

  Status CheckExternalState() const override { return OkStatus(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    Node* shard_dir = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(shard_dir_, &shard_dir));

    Node* start_index = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(start_index_, &start_index));

    AttrValue compression;
    b->BuildAttrValue(compression_, &compression);

    AttrValue version;
    b->BuildAttrValue(version_, &version);

    return b->AddDataset(
        this,
        /*inputs=*/
        {std::make_pair(0, shard_dir), std::make_pair(1, start_index)},
        /*list_inputs=*/{},
        /*attrs=*/
        {{kCompression, compression}, {kVersion, version}},
        /*use_dataset_name=*/true, node);
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(node_name(), prefix)});
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          start_index_(dataset()->start_index_) {}

    Status Initialize(IteratorContext* ctx) override {
      // TODO(jsimsa): This only needs to happen when we are not restoring, but
      // the parallel_interleave op implementation caches the IteratorContext
      // (and thus the is_restoring bit ends up being inaccurate).
      TF_RETURN_IF_ERROR(Reader::Create(
          ctx->env(), GetCurrentFilename(), dataset()->compression_,
          dataset()->version_, dataset()->dtypes_, &reader_));
      return AdvanceToStartIndex(ctx);
    }

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      *end_of_sequence = false;
      Status s = reader_->ReadTensors(out_tensors);
      if (!errors::IsOutOfRange(s)) {
        start_index_++;
        return s;
      }
      Status status = AdvanceToNextFile(ctx->env());
      if (errors::IsNotFound(status)) {
        *end_of_sequence = true;
        return OkStatus();
      }
      return status;
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentCheckpointID),
                                             current_checkpoint_id_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(kStartIndex), start_index_));
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentCheckpointID),
                                            &current_checkpoint_id_));
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kStartIndex), &start_index_));
      TF_RETURN_IF_ERROR(ctx->env()->FileExists(GetCurrentFilename()));
      TF_RETURN_IF_ERROR(Reader::Create(
          ctx->env(), GetCurrentFilename(), dataset()->compression_,
          dataset()->version_, dataset()->dtypes_, &reader_));
      return AdvanceToStartIndex(ctx);
    }

   private:
    Status AdvanceToNextFile(Env* env) {
      start_index_ = 0;
      current_checkpoint_id_++;
      TF_RETURN_IF_ERROR(env->FileExists(GetCurrentFilename()));
      return Reader::Create(env, GetCurrentFilename(), dataset()->compression_,
                            dataset()->version_, dataset()->dtypes_, &reader_);
    }

    std::string GetCurrentFilename() {
      return GetCheckpointFileName(dataset()->shard_dir_,
                                   current_checkpoint_id_);
    }

    // TODO(frankchn): Optimize this to not parse every single element.
    Status AdvanceToStartIndex(IteratorContext* ctx) {
      for (int64_t i = 0; i < start_index_; ++i) {
        std::vector<Tensor> unused;
        TF_RETURN_IF_ERROR(reader_->ReadTensors(&unused));
      }
      return OkStatus();
    }

    std::unique_ptr<Reader> reader_;

    // Stores the id of the current checkpoint file that we are in the process
    // of reading (e.g. if the file is currently 00000001.snapshot, then this
    // will be 1).
    int64_t current_checkpoint_id_ = 0;
    int64_t start_index_;
  };

  const tstring shard_dir_;
  const std::string compression_;
  const int64_t version_;
  const DataTypeVector dtypes_;
  const std::vector<PartialTensorShape> shapes_;
  const int64_t start_index_;
};

Reader::DatasetOp::DatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kVersion, &version_));
}

void Reader::DatasetOp::MakeDataset(OpKernelContext* ctx,
                                    DatasetBase** output) {
  tstring shard_dir;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "shard_dir", &shard_dir));

  int64_t start_index;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "start_index", &start_index));

  *output =
      new Reader::Dataset(DatasetContext(ctx), shard_dir, compression_,
                          version_, output_types_, output_shapes_, start_index);
}

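// A dataset whose elements are variant tensors, each wrapping one of the
// per-shard reader datasets above.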
class Reader::NestedDataset : public DatasetBase {
 public:
  explicit NestedDataset(DatasetContext&& ctx,
                         std::vector<DatasetBase*> datasets)
      : DatasetBase(std::move(ctx)), datasets_(datasets) {
    dtypes_.push_back(DT_VARIANT);
    gtl::InlinedVector<int64_t, 1> element_dim_sizes;
    element_dim_sizes.push_back(1);
    partial_shapes_.emplace_back(element_dim_sizes);
  }

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return partial_shapes_;
  }

  std::string DebugString() const override {
    return "SnapshotNestedDatasetReader";
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->clear();
    return OkStatus();
  }

  Status CheckExternalState() const override { return OkStatus(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    std::vector<Node*> input_graph_nodes;
    input_graph_nodes.reserve(datasets_.size());
    for (const auto& dataset : datasets_) {
      Node* input_node;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, dataset, &input_node));
      input_graph_nodes.emplace_back(input_node);
    }
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, /*inputs=*/{},
                      /*list_inputs=*/{std::make_pair(0, input_graph_nodes)},
                      /*attrs=*/{}, node));
    return OkStatus();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(node_name(), prefix)});
  }

 private:
  std::vector<DatasetBase*> datasets_;
  DataTypeVector dtypes_;
  std::vector<PartialTensorShape> partial_shapes_;

  class Iterator : public DatasetIterator<NestedDataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<NestedDataset>(params) {}

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const int64_t num_datasets = dataset()->datasets_.size();
      *end_of_sequence = num_datasets == index_;
      if (!*end_of_sequence) {
        Tensor tensor(DT_VARIANT, TensorShape({}));

        TF_RETURN_IF_ERROR(
            StoreDatasetInVariantTensor(dataset()->datasets_[index_], &tensor));
        out_tensors->clear();
        out_tensors->push_back(std::move(tensor));

        index_++;
      }
      return OkStatus();
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_));
      return OkStatus();
    }

   private:
    int64_t index_ = 0;
  };
};

Reader::NestedDatasetOp::NestedDatasetOp(OpKernelConstruction* ctx)
    : DatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

void Reader::NestedDatasetOp::MakeDataset(OpKernelContext* ctx,
                                          DatasetBase** output) {
  std::vector<DatasetBase*> inputs;
  for (size_t i = 0; i < ctx->num_inputs(); ++i) {
    DatasetBase* input;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
    inputs.push_back(input);
  }
  *output = new Reader::NestedDataset(DatasetContext(ctx), inputs);
  (*output)->Initialize(/*metadata=*/{});
}

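// Builds a NestedDataset with one per-shard reader dataset per shard
// directory. Elements are assumed to have been written round-robin across
// shards (see the TODO below), so `start_index` is split into a per-shard
// start index and the shard datasets are rotated so that the shard holding the
// next element comes first.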
Status Reader::MakeNestedDataset(Env* env,
                                 const std::vector<std::string>& shard_dirs,
                                 const string& compression_type, int version,
                                 const DataTypeVector& dtypes,
                                 const std::vector<PartialTensorShape>& shapes,
                                 const int64_t start_index,
                                 DatasetBase** output) {
  std::vector<DatasetBase*> datasets;

  datasets.reserve(shard_dirs.size());
  for (int64_t i = 0; i < shard_dirs.size(); ++i) {
    // TODO(frankchn): The reading pattern could be controlled in a
    // non-round-robin fashion, so we cannot assume a round-robin manner when
    // restoring.
    int64_t dataset_start_index = start_index / shard_dirs.size();
    if (start_index % shard_dirs.size() > datasets.size()) {
      dataset_start_index++;
    }

    datasets.push_back(
        new Dataset(DatasetContext(DatasetContext::Params(
                        {"SnapshotDatasetReader",
                         strings::StrCat("SnapshotDatasetReader/_", i)})),
                    shard_dirs.at(i), compression_type, version, dtypes, shapes,
                    dataset_start_index));
    datasets.back()->Initialize(/*metadata=*/{});
  }

  // Rotate the vector such that the first dataset contains the next element
  // to be produced, but not if there are no shards at all (then we just
  // construct an empty dataset).
  if (!shard_dirs.empty()) {
    std::rotate(datasets.begin(),
                datasets.begin() + (start_index % shard_dirs.size()),
                datasets.end());
  }

  *output = new NestedDataset(
      DatasetContext(DatasetContext::Params(
          {"SnapshotNestedDatasetReader", "SnapshotNestedDatasetReader"})),
      datasets);
  (*output)->Initialize(/*metadata=*/{});
  return OkStatus();
}

TFRecordReader::TFRecordReader(const std::string& filename,
                               const string& compression_type,
                               const DataTypeVector& dtypes)
    : filename_(filename),
      offset_(0),
      compression_type_(compression_type),
      dtypes_(dtypes) {}

Status TFRecordReader::Initialize(Env* env) {
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));

  record_reader_ = std::make_unique<io::RecordReader>(
      file_.get(), io::RecordReaderOptions::CreateRecordReaderOptions(
                       /*compression_type=*/compression_type_));
  return OkStatus();
}

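// Reads one element by parsing one TFRecord (one TensorProto) per expected
// output dtype.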
Status TFRecordReader::ReadTensors(std::vector<Tensor>* read_tensors) {
  read_tensors->reserve(dtypes_.size());
  for (int i = 0; i < dtypes_.size(); ++i) {
    tstring record;
    TF_RETURN_IF_ERROR(record_reader_->ReadRecord(&offset_, &record));

    TensorProto proto;
    proto.ParseFromArray(record.data(), record.size());

    Tensor tensor;
    if (!tensor.FromProto(proto)) {
      return errors::DataLoss("Unable to parse tensor from stored proto.");
    }

    read_tensors->push_back(std::move(tensor));
  }
  return OkStatus();
}

CustomReader::CustomReader(const std::string& filename,
                           const string& compression_type, const int version,
                           const DataTypeVector& dtypes)
    : filename_(filename),
      compression_type_(compression_type),
      version_(version),
      dtypes_(dtypes) {}

Status CustomReader::Initialize(Env* env) {
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
  input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());

#if defined(IS_SLIM_BUILD)
  if (compression_type_ != io::compression::kNone) {
    LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
               << "off compression.";
  }
#else  // IS_SLIM_BUILD
  if (compression_type_ == io::compression::kGzip) {
    io::ZlibCompressionOptions zlib_options;
    zlib_options = io::ZlibCompressionOptions::GZIP();

    input_stream_ = std::make_unique<io::ZlibInputStream>(
        input_stream_.release(), zlib_options.input_buffer_size,
        zlib_options.output_buffer_size, zlib_options, true);
  } else if (compression_type_ == io::compression::kSnappy) {
    if (version_ == 0) {
      input_stream_ = std::make_unique<io::SnappyInputBuffer>(
          file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
          /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
    } else {
      input_stream_ =
          std::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
    }
  }
#endif  // IS_SLIM_BUILD
  simple_tensor_mask_.reserve(dtypes_.size());
  for (const auto& dtype : dtypes_) {
    if (DataTypeCanUseMemcpy(dtype)) {
      simple_tensor_mask_.push_back(true);
      num_simple_++;
    } else {
      simple_tensor_mask_.push_back(false);
      num_complex_++;
    }
  }

  return OkStatus();
}

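// Reads one element. Version 0 files and non-snappy compression use the
// proto-per-record path (ReadTensorsV0); version 1 snappy files read a
// SnapshotTensorMetadata record followed by a snappy-compressed block that is
// decompressed directly into the output tensor buffers.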
Status CustomReader::ReadTensors(std::vector<Tensor>* read_tensors) {
  profiler::TraceMe activity(
      [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
      profiler::TraceMeLevel::kInfo);
  if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
    return ReadTensorsV0(read_tensors);
  }
  if (version_ != 1) {
    return errors::InvalidArgument("Version: ", version_, " is not supported.");
  }
  if (compression_type_ != io::compression::kSnappy) {
    return errors::InvalidArgument("Compression ", compression_type_,
                                   " is not supported.");
  }

  experimental::SnapshotTensorMetadata metadata;
  tstring metadata_str;
  TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
  if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
    return errors::DataLoss("Could not parse SnapshotTensorMetadata");
  }
  read_tensors->reserve(metadata.tensor_metadata_size());

  std::vector<Tensor> simple_tensors;
  simple_tensors.reserve(num_simple_);
  std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
  tensor_proto_strs.reserve(num_complex_);
  TF_RETURN_IF_ERROR(
      SnappyUncompress(&metadata, &simple_tensors, &tensor_proto_strs));

  int simple_index = 0;
  int complex_index = 0;
  for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
    if (simple_tensor_mask_[i]) {
      read_tensors->push_back(std::move(simple_tensors[simple_index]));
      simple_index++;
    } else {
      auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
      size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
      TensorProto tp;
      if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
        return errors::Internal("Could not parse TensorProto");
      }
      Tensor t;
      if (!t.FromProto(tp)) {
        return errors::Internal("Could not parse Tensor");
      }
      read_tensors->push_back(std::move(t));
      complex_index++;
    }
  }
  return OkStatus();
}

Status CustomReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
  experimental::SnapshotRecord record;
#if defined(PLATFORM_GOOGLE)
  absl::Cord c;
  TF_RETURN_IF_ERROR(ReadRecord(&c));
  record.ParseFromCord(c);
#else  // PLATFORM_GOOGLE
  tstring record_bytes;
  TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
  record.ParseFromArray(record_bytes.data(), record_bytes.size());
#endif  // PLATFORM_GOOGLE
  read_tensors->reserve(record.tensor_size());
  for (int i = 0; i < record.tensor_size(); ++i) {
    read_tensors->emplace_back();
    if (!read_tensors->back().FromProto(record.tensor(i))) {
      return errors::DataLoss("Unable to parse tensor from proto.");
    }
  }
  return OkStatus();
}

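// Decompresses one snappy block into a set of iovecs: memcpy-able tensors are
// decompressed straight into freshly allocated Tensor buffers, while the
// remaining tensors land in char buffers that are later parsed as
// TensorProtos. The per-tensor sizes come from the metadata record and are
// validated against snappy's uncompressed length.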
Status CustomReader::SnappyUncompress(
    const experimental::SnapshotTensorMetadata* metadata,
    std::vector<Tensor>* simple_tensors,
    std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
        tensor_proto_strs) {
  tstring compressed;
  TF_RETURN_IF_ERROR(ReadRecord(&compressed));
  size_t size;
  if (!port::Snappy_GetUncompressedLength(compressed.data(), compressed.size(),
                                          &size)) {
    return errors::Internal("Could not get snappy uncompressed length");
  }

  int num_tensors = metadata->tensor_metadata_size();
  std::vector<struct iovec> iov(num_tensors);
  int index = 0;
  int64_t total_size = 0;
  for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
    const auto& tensor_metadata = metadata->tensor_metadata(i);
    if (simple_tensor_mask_[i]) {
      TensorShape shape(tensor_metadata.tensor_shape());
      Tensor simple_tensor(dtypes_[i], shape);
      TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
      iov[index].iov_base = buffer->data();
      iov[index].iov_len = buffer->size();
      simple_tensors->push_back(std::move(simple_tensor));
    } else {
      auto tensor_proto_str =
          std::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
      iov[index].iov_base = tensor_proto_str.get();
      iov[index].iov_len = tensor_metadata.tensor_size_bytes();
      tensor_proto_strs->push_back(std::make_pair(
          std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
    }
    total_size += iov[index].iov_len;
    index++;
  }
  const int64_t size_int = size;
  if (size_int != total_size) {
    return errors::Internal("Uncompressed size mismatch. Snappy expects ", size,
                            " whereas the tensor metadata suggests ",
                            total_size);
  }
  if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
                                      iov.data(), num_tensors)) {
    return errors::Internal("Failed to perform snappy decompression.");
  }
  return OkStatus();
}

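// Reads one length-prefixed record: an 8-byte length header followed by that
// many payload bytes, mirroring CustomWriter::WriteRecord above.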
Status CustomReader::ReadRecord(tstring* record) {
  tstring header;
  TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
  uint64 length = core::DecodeFixed64(header.data());
  return input_stream_->ReadNBytes(length, record);
}

#if defined(TF_CORD_SUPPORT)
Status CustomReader::ReadRecord(absl::Cord* record) {
  tstring header;
  TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
  uint64 length = core::DecodeFixed64(header.data());
  if (compression_type_ == io::compression::kNone) {
    return input_stream_->ReadNBytes(length, record);
  } else {
    auto tmp_str = new tstring();
    TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str));
    absl::string_view tmp_str_view(*tmp_str);
    record->Append(absl::MakeCordFromExternal(
        tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; }));
    return OkStatus();
  }
}
#endif  // TF_CORD_SUPPORT

Status WriteMetadataFile(Env* env, const string& dir,
                         const experimental::SnapshotMetadataRecord* metadata) {
  string metadata_filename = io::JoinPath(dir, kMetadataFilename);
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir));
  std::string tmp_filename =
      absl::StrCat(metadata_filename, "-tmp-", random::New64());
  TF_RETURN_IF_ERROR(WriteBinaryProto(env, tmp_filename, *metadata));
  return env->RenameFile(tmp_filename, metadata_filename);
}

Status ReadMetadataFile(Env* env, const string& dir,
                        experimental::SnapshotMetadataRecord* metadata,
                        bool* file_exists) {
  string metadata_filename = io::JoinPath(dir, kMetadataFilename);
  Status s = env->FileExists(metadata_filename);
  *file_exists = s.ok();

  if (*file_exists) {
    return ReadBinaryProto(env, metadata_filename, metadata);
  } else {
    return OkStatus();
  }
}

Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
                        const GraphDef* graph) {
  std::string hash_hex =
      strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
  std::string graph_file =
      io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt"));

  LOG(INFO) << "Graph hash is " << hash_hex << ", writing to " << graph_file;
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(path));
  return WriteTextProto(env, graph_file, *graph);
}

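// Decides whether the snapshot op should read, write, or pass through.
// Explicit mode strings override the decision; otherwise the choice depends on
// whether a metadata file exists, whether the snapshot is finalized, and
// whether a pending (unfinalized) snapshot has expired.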
Status DetermineOpState(const std::string& mode_string, bool file_exists,
                        const experimental::SnapshotMetadataRecord* metadata,
                        const uint64 pending_snapshot_expiry_seconds,
                        Mode* mode) {
  if (mode_string == kModeRead) {
    // In read mode, we expect the metadata file to have already been written.
    if (!file_exists) {
      return errors::NotFound("Metadata file does not exist.");
    }
    LOG(INFO) << "Overriding mode to reader.";
    *mode = READER;
    return OkStatus();
  }

  if (mode_string == kModeWrite) {
    LOG(INFO) << "Overriding mode to writer.";
    *mode = WRITER;
    return OkStatus();
  }

  if (mode_string == kModePassthrough) {
    LOG(INFO) << "Overriding mode to passthrough.";
    *mode = PASSTHROUGH;
    return OkStatus();
  }

  if (!file_exists) {
    *mode = WRITER;
    return OkStatus();
  }

  if (metadata->finalized()) {
    // File found, snapshot has been finalized.
    *mode = READER;
    return OkStatus();
  }

  int64_t expiration_timer = static_cast<int64_t>(EnvTime::NowMicros()) -
                             pending_snapshot_expiry_seconds * 1000000;

  if (metadata->creation_timestamp() >= expiration_timer) {
    // Someone else is already writing and time has not expired.
    *mode = PASSTHROUGH;
    return OkStatus();
  } else {
    // Time has expired, we write regardless.
    *mode = WRITER;
    return OkStatus();
  }
}

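// Starts a background thread that drains the queue of elements fed by Write()
// and writes them to a new checkpoint file; SignalEOF() closes the file and
// `done` is invoked with the final status.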
AsyncWriter::AsyncWriter(Env* env, int64_t file_index,
                         const std::string& shard_directory,
                         uint64 checkpoint_id, const std::string& compression,
                         int64_t version, const DataTypeVector& output_types,
                         std::function<void(Status)> done) {
  thread_ = absl::WrapUnique(env->StartThread(
      ThreadOptions(), absl::StrCat("writer_thread_", file_index),
      [this, env, shard_directory, checkpoint_id, compression, version,
       &output_types, done = std::move(done)] {
        done(WriterThread(env, shard_directory, checkpoint_id, compression,
                          version, output_types));
      }));
}

void AsyncWriter::Write(const std::vector<Tensor>& tensors) {
  mutex_lock l(mu_);
  ElementOrEOF element;
  element.value = tensors;
  deque_.push_back(std::move(element));
}

void AsyncWriter::SignalEOF() {
  mutex_lock l(mu_);
  ElementOrEOF be;
  be.end_of_sequence = true;
  deque_.push_back(std::move(be));
}

void AsyncWriter::Consume(ElementOrEOF* be) {
  mutex_lock l(mu_);
  mu_.Await(tensorflow::Condition(this, &AsyncWriter::ElementAvailable));
  *be = deque_.front();
  deque_.pop_front();
}

bool AsyncWriter::ElementAvailable() { return !deque_.empty(); }

Status AsyncWriter::WriterThread(Env* env, const std::string& shard_directory,
                                 uint64 checkpoint_id,
                                 const std::string& compression,
                                 int64_t version, DataTypeVector output_types) {
  std::unique_ptr<snapshot_util::Writer> writer;
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));

  TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
      env, GetCheckpointFileName(shard_directory, checkpoint_id), compression,
      version, std::move(output_types), &writer));

  while (true) {
    ElementOrEOF be;
    Consume(&be);

    if (be.end_of_sequence) {
      TF_RETURN_IF_ERROR(writer->Close());
      break;
    }

    TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
  }
  return OkStatus();
}

namespace {

REGISTER_KERNEL_BUILDER(Name("SnapshotDatasetReader").Device(DEVICE_CPU),
                        Reader::DatasetOp);
REGISTER_KERNEL_BUILDER(Name("SnapshotNestedDatasetReader").Device(DEVICE_CPU),
                        Reader::NestedDatasetOp);

}  // namespace
}  // namespace snapshot_util
}  // namespace data
}  // namespace tensorflow