/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/snapshot_utils.h"

#include <algorithm>
#include <functional>
#include <queue>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
#include "tensorflow/core/lib/io/random_inputstream.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/lib/io/snappy/snappy_inputbuffer.h"
#include "tensorflow/core/lib/io/snappy/snappy_outputbuffer.h"
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_inputstream.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/platform/coding.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringprintf.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/snapshot.pb.h"

namespace tensorflow {
namespace data {
namespace snapshot_util {
namespace {

constexpr const char* const kOutputTypes = "output_types";
constexpr const char* const kOutputShapes = "output_shapes";
constexpr const char* const kCompression = "compression";
constexpr const char* const kVersion = "version";
constexpr const char* const kCurrentCheckpointID = "current_checkpoint_id";
constexpr const char* const kIndex = "index";
constexpr const char* const kStartIndex = "start_index";

}  // namespace

/* static */ constexpr const int64_t
    CustomReader::kSnappyReaderInputBufferSizeBytes;
/* static */ constexpr const int64_t
    CustomReader::kSnappyReaderOutputBufferSizeBytes;

std::string HashDirectory(const std::string& path, uint64 hash) {
  return io::JoinPath(
      path, strings::Printf("%llu", static_cast<unsigned long long>(hash)));
}

std::string RunDirectory(const std::string& hash_directory, uint64 run_id) {
  return RunDirectory(
      hash_directory,
      strings::Printf("%llu", static_cast<unsigned long long>(run_id)));
}

std::string RunDirectory(const std::string& hash_directory,
                         const std::string& run_id) {
  return io::JoinPath(hash_directory, run_id);
}

std::string ShardDirectory(const std::string& run_directory, int64_t shard_id) {
  return io::JoinPath(
      run_directory,
      strings::Printf("%08llu%s", static_cast<unsigned long long>(shard_id),
                      kShardDirectorySuffix));
}

std::string GetCheckpointFileName(const std::string& shard_directory,
                                  uint64 checkpoint_id) {
  return io::JoinPath(
      shard_directory,
      strings::Printf("%08llu.snapshot",
                      static_cast<unsigned long long>(checkpoint_id)));
}
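
// Putting the path helpers together, the on-disk layout looks like this
// (hypothetical ids; assumes kShardDirectorySuffix is ".shard"):
//   HashDirectory("/data/snap", 43981)  -> "/data/snap/43981"
//   RunDirectory(".../43981", 17)       -> "/data/snap/43981/17"
//   ShardDirectory(".../17", 3)         -> "/data/snap/43981/17/00000003.shard"
//   GetCheckpointFileName(".../00000003.shard", 5)
//       -> "/data/snap/43981/17/00000003.shard/00000005.snapshot"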

Status Writer::Create(Env* env, const std::string& filename,
                      const std::string& compression_type, int version,
                      const DataTypeVector& dtypes,
                      std::unique_ptr<Writer>* out_writer) {
  switch (version) {
    case 1:
      *out_writer =
          std::make_unique<CustomWriter>(filename, compression_type, dtypes);
      break;
    case 2:
      *out_writer =
          std::make_unique<TFRecordWriter>(filename, compression_type);
      break;
    default:
      return errors::InvalidArgument("Snapshot writer version: ", version,
                                     " is not supported.");
  }

  return (*out_writer)->Initialize(env);
}
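
// Minimal usage sketch for the factory above (hypothetical filename; callers
// normally live inside a kernel, so take this as illustration only):
//   std::unique_ptr<Writer> writer;
//   TF_RETURN_IF_ERROR(Writer::Create(Env::Default(), "/tmp/00000000.snapshot",
//                                     io::compression::kSnappy, /*version=*/2,
//                                     {DT_FLOAT}, &writer));
//   Tensor t(DT_FLOAT, TensorShape({}));
//   t.scalar<float>()() = 1.0f;
//   TF_RETURN_IF_ERROR(writer->WriteTensors({t}));  // Writes one element.
//   TF_RETURN_IF_ERROR(writer->Close());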

TFRecordWriter::TFRecordWriter(const std::string& filename,
                               const std::string& compression_type)
    : filename_(filename), compression_type_(compression_type) {}

Status TFRecordWriter::Initialize(tensorflow::Env* env) {
  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));

  record_writer_ = std::make_unique<io::RecordWriter>(
      dest_.get(), io::RecordWriterOptions::CreateRecordWriterOptions(
                       /*compression_type=*/compression_type_));
  return OkStatus();
}

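// Serializes each tensor as a TensorProto and writes it as its own TFRecord;
// TFRecordReader::ReadTensors below reads back exactly one record per dtype.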
Status TFRecordWriter::WriteTensors(const std::vector<Tensor>& tensors) {
  for (const auto& tensor : tensors) {
    TensorProto proto;
    tensor.AsProtoTensorContent(&proto);
#if defined(TF_CORD_SUPPORT)
    // Create a raw pointer here: moving a smart pointer into the Cord's
    // releaser would transfer ownership as soon as the releaser is created,
    // leaving proto_buffer == nullptr by the time WriteRecord runs.
    auto proto_buffer = new std::string();
    proto.SerializeToString(proto_buffer);
    absl::Cord proto_serialized = absl::MakeCordFromExternal(
        *proto_buffer,
        [proto_buffer](absl::string_view) { delete proto_buffer; });
    TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto_serialized));
#else   // TF_CORD_SUPPORT
    TF_RETURN_IF_ERROR(record_writer_->WriteRecord(proto.SerializeAsString()));
#endif  // TF_CORD_SUPPORT
  }
  return OkStatus();
}

Status TFRecordWriter::Sync() {
  TF_RETURN_IF_ERROR(record_writer_->Flush());
  return dest_->Flush();
}

Status TFRecordWriter::Close() {
  if (record_writer_ != nullptr) {
    TF_RETURN_IF_ERROR(Sync());
    TF_RETURN_IF_ERROR(record_writer_->Close());
    TF_RETURN_IF_ERROR(dest_->Close());
    record_writer_ = nullptr;
    dest_ = nullptr;
  }
  return OkStatus();
}

TFRecordWriter::~TFRecordWriter() {
  Status s = Close();
  if (!s.ok()) {
    LOG(ERROR) << "Failed to close snapshot file " << filename_ << ": " << s;
  }
}

CustomWriter::CustomWriter(const std::string& filename,
                           const std::string& compression_type,
                           const DataTypeVector& dtypes)
    : filename_(filename),
      compression_type_(compression_type),
      dtypes_(dtypes) {}

Status CustomWriter::Initialize(tensorflow::Env* env) {
  TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_));
#if defined(IS_SLIM_BUILD)
  if (compression_type_ != io::compression::kNone) {
    LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
               << "off compression.";
  }
#else   // IS_SLIM_BUILD
  if (compression_type_ == io::compression::kGzip) {
    zlib_underlying_dest_.swap(dest_);
    io::ZlibCompressionOptions zlib_options;
    zlib_options = io::ZlibCompressionOptions::GZIP();

    io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
        zlib_underlying_dest_.get(), zlib_options.input_buffer_size,
        zlib_options.output_buffer_size, zlib_options);
    TF_CHECK_OK(zlib_output_buffer->Init());
    dest_.reset(zlib_output_buffer);
  }
#endif  // IS_SLIM_BUILD
  simple_tensor_mask_.reserve(dtypes_.size());
  for (const auto& dtype : dtypes_) {
    if (DataTypeCanUseMemcpy(dtype)) {
      simple_tensor_mask_.push_back(true);
      num_simple_++;
    } else {
      simple_tensor_mask_.push_back(false);
      num_complex_++;
    }
  }

  return OkStatus();
}

Status CustomWriter::WriteTensors(const std::vector<Tensor>& tensors) {
  if (compression_type_ != io::compression::kSnappy) {
    experimental::SnapshotRecord record;
    for (const auto& tensor : tensors) {
      TensorProto* t = record.add_tensor();
      tensor.AsProtoTensorContent(t);
    }
#if defined(TF_CORD_SUPPORT)
    auto record_buffer = new std::string();
    record.SerializeToString(record_buffer);
    absl::Cord record_serialized = absl::MakeCordFromExternal(
        *record_buffer,
        [record_buffer](absl::string_view) { delete record_buffer; });
    return WriteRecord(record_serialized);
#else   // TF_CORD_SUPPORT
    return WriteRecord(record.SerializeAsString());
#endif  // TF_CORD_SUPPORT
  }

  std::vector<const TensorBuffer*> tensor_buffers;
  tensor_buffers.reserve(num_simple_);
  std::vector<TensorProto> tensor_protos;
  tensor_protos.reserve(num_complex_);
  experimental::SnapshotTensorMetadata metadata;
  int64_t total_size = 0;
  for (int i = 0, end = tensors.size(); i < end; ++i) {
    const Tensor& tensor = tensors[i];
    experimental::TensorMetadata* tensor_metadata =
        metadata.add_tensor_metadata();
    tensor.shape().AsProto(tensor_metadata->mutable_tensor_shape());
    int64_t size = 0;
    if (simple_tensor_mask_[i]) {
      auto tensor_buffer = DMAHelper::buffer(&tensor);
      tensor_buffers.push_back(tensor_buffer);
      size = tensor_buffer->size();
    } else {
      TensorProto proto;
      tensor.AsProtoTensorContent(&proto);
      size = proto.ByteSizeLong();
      tensor_protos.push_back(std::move(proto));
    }
    tensor_metadata->set_tensor_size_bytes(size);
    total_size += size;
  }

  std::vector<char> uncompressed(total_size);
  char* position = uncompressed.data();
  int buffer_index = 0;
  int proto_index = 0;
  for (int i = 0, end = tensors.size(); i < end; ++i) {
    const auto& tensor_metadata = metadata.tensor_metadata(i);
    if (simple_tensor_mask_[i]) {
      memcpy(position, tensor_buffers[buffer_index]->data(),
             tensor_metadata.tensor_size_bytes());
      buffer_index++;
    } else {
      tensor_protos[proto_index].SerializeToArray(
          position, tensor_metadata.tensor_size_bytes());
      proto_index++;
    }
    position += tensor_metadata.tensor_size_bytes();
  }
  DCHECK_EQ(position, uncompressed.data() + total_size);

  string output;
  if (!port::Snappy_Compress(uncompressed.data(), total_size, &output)) {
    return errors::Internal("Failed to compress using snappy.");
  }

#if defined(TF_CORD_SUPPORT)
  auto metadata_buffer = new std::string();
  metadata.SerializeToString(metadata_buffer);
  absl::Cord metadata_serialized = absl::MakeCordFromExternal(
      *metadata_buffer,
      [metadata_buffer](absl::string_view) { delete metadata_buffer; });
#else
  std::string metadata_serialized = metadata.SerializeAsString();
#endif  // TF_CORD_SUPPORT
  TF_RETURN_IF_ERROR(WriteRecord(metadata_serialized));
  TF_RETURN_IF_ERROR(WriteRecord(output));
  return OkStatus();
}
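
// In the snappy case above, each element becomes two framed records: a
// SnapshotTensorMetadata proto carrying per-tensor shapes and byte sizes,
// followed by one snappy-compressed blob that concatenates the raw buffers of
// memcpy-able tensors and the serialized TensorProtos of the remaining ones.
// CustomReader::SnappyUncompress reverses this layout.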

Status CustomWriter::Sync() { return dest_->Sync(); }

Status CustomWriter::Close() {
  if (dest_ != nullptr) {
    TF_RETURN_IF_ERROR(dest_->Close());
    dest_ = nullptr;
  }
  if (zlib_underlying_dest_ != nullptr) {
    TF_RETURN_IF_ERROR(zlib_underlying_dest_->Close());
    zlib_underlying_dest_ = nullptr;
  }
  return OkStatus();
}

CustomWriter::~CustomWriter() {
  Status s = Close();
  if (!s.ok()) {
    LOG(ERROR) << "Could not finish writing file: " << s;
  }
}

Status CustomWriter::WriteRecord(const StringPiece& data) {
  char header[kHeaderSize];
  core::EncodeFixed64(header, data.size());
  TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
  return dest_->Append(data);
}

#if defined(TF_CORD_SUPPORT)
Status CustomWriter::WriteRecord(const absl::Cord& data) {
  char header[kHeaderSize];
  core::EncodeFixed64(header, data.size());
  TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
  return dest_->Append(data);
}
#endif  // TF_CORD_SUPPORT
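
// Both WriteRecord overloads share the same framing: an 8-byte fixed64 length
// header (little-endian, written by core::EncodeFixed64) followed by that many
// payload bytes. CustomReader::ReadRecord decodes the identical framing.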

Status Reader::Create(Env* env, const std::string& filename,
                      const string& compression_type, int version,
                      const DataTypeVector& dtypes,
                      std::unique_ptr<Reader>* out_reader) {
  switch (version) {
    // CustomReader can still read the legacy snapshot file format (v0), but
    // CustomWriter no longer writes it, since v0 is strictly worse than v1.
    case 0:
    case 1:
      *out_reader = std::make_unique<CustomReader>(filename, compression_type,
                                                   version, dtypes);
      break;
    case 2:
      *out_reader =
          std::make_unique<TFRecordReader>(filename, compression_type, dtypes);
      break;
    default:
      return errors::InvalidArgument("Snapshot reader version: ", version,
                                     " is not supported.");
  }

  return (*out_reader)->Initialize(env);
}
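
// Minimal read-back sketch mirroring the Writer::Create example earlier
// (hypothetical filename):
//   std::unique_ptr<Reader> reader;
//   TF_RETURN_IF_ERROR(Reader::Create(Env::Default(), "/tmp/00000000.snapshot",
//                                     io::compression::kSnappy, /*version=*/2,
//                                     {DT_FLOAT}, &reader));
//   std::vector<Tensor> element;
//   TF_RETURN_IF_ERROR(reader->ReadTensors(&element));  // Reads one element.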

Status Reader::SkipRecords(int64_t num_records) {
  // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip.
  for (int i = 0; i < num_records; ++i) {
    std::vector<Tensor> unused_tensors;
    TF_RETURN_IF_ERROR(ReadTensors(&unused_tensors));
  }
  return OkStatus();
}

class Reader::Dataset : public DatasetBase {
 public:
  Dataset(DatasetContext&& ctx, const std::string& shard_dir,
          const std::string& compression, const int64_t version,
          const DataTypeVector& dtypes,
          const std::vector<PartialTensorShape>& shapes,
          const int64_t start_index)
      : DatasetBase(std::move(ctx)),
        shard_dir_(shard_dir),
        compression_(compression),
        version_(version),
        dtypes_(dtypes),
        shapes_(shapes),
        start_index_(start_index) {}

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return shapes_;
  }

  std::string DebugString() const override { return "SnapshotDatasetReader"; }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    return OkStatus();
  }

  Status CheckExternalState() const override { return OkStatus(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    Node* shard_dir = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(shard_dir_, &shard_dir));

    Node* start_index = nullptr;
    TF_RETURN_IF_ERROR(b->AddScalar(start_index_, &start_index));

    AttrValue compression;
    b->BuildAttrValue(compression_, &compression);

    AttrValue version;
    b->BuildAttrValue(version_, &version);

    return b->AddDataset(
        this,
        /*inputs=*/
        {std::make_pair(0, shard_dir), std::make_pair(1, start_index)},
        /*list_inputs=*/{},
        /*attrs=*/
        {{kCompression, compression}, {kVersion, version}},
        /*use_dataset_name=*/true, node);
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(node_name(), prefix)});
  }

 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<Dataset>(params),
          start_index_(dataset()->start_index_) {}

    Status Initialize(IteratorContext* ctx) override {
      // TODO(jsimsa): This only needs to happen when we are not restoring, but
      // the parallel_interleave op implementation caches IteratorContext (and
      // thus the is_restoring bit ends up being inaccurate).
      TF_RETURN_IF_ERROR(Reader::Create(
          ctx->env(), GetCurrentFilename(), dataset()->compression_,
          dataset()->version_, dataset()->dtypes_, &reader_));
      return AdvanceToStartIndex(ctx);
    }

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      *end_of_sequence = false;
      Status s = reader_->ReadTensors(out_tensors);
      if (!errors::IsOutOfRange(s)) {
        start_index_++;
        return s;
      }
      Status status = AdvanceToNextFile(ctx->env());
      if (errors::IsNotFound(status)) {
        *end_of_sequence = true;
        return OkStatus();
      }
      return status;
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentCheckpointID),
                                             current_checkpoint_id_));
      TF_RETURN_IF_ERROR(
          writer->WriteScalar(full_name(kStartIndex), start_index_));
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentCheckpointID),
                                            &current_checkpoint_id_));
      TF_RETURN_IF_ERROR(
          reader->ReadScalar(full_name(kStartIndex), &start_index_));
      TF_RETURN_IF_ERROR(ctx->env()->FileExists(GetCurrentFilename()));
      TF_RETURN_IF_ERROR(Reader::Create(
          ctx->env(), GetCurrentFilename(), dataset()->compression_,
          dataset()->version_, dataset()->dtypes_, &reader_));
      return AdvanceToStartIndex(ctx);
    }

   private:
    Status AdvanceToNextFile(Env* env) {
      start_index_ = 0;
      current_checkpoint_id_++;
      TF_RETURN_IF_ERROR(env->FileExists(GetCurrentFilename()));
      return Reader::Create(env, GetCurrentFilename(), dataset()->compression_,
                            dataset()->version_, dataset()->dtypes_, &reader_);
    }

    std::string GetCurrentFilename() {
      return GetCheckpointFileName(dataset()->shard_dir_,
                                   current_checkpoint_id_);
    }

    // TODO(frankchn): Optimize this to not parse every single element.
    Status AdvanceToStartIndex(IteratorContext* ctx) {
      for (int64_t i = 0; i < start_index_; ++i) {
        std::vector<Tensor> unused;
        TF_RETURN_IF_ERROR(reader_->ReadTensors(&unused));
      }
      return OkStatus();
    }

    std::unique_ptr<Reader> reader_;

    // Stores the id of the checkpoint file that we are currently reading
    // (e.g. if the file is 00000001.snapshot, then this will be 1).
    int64_t current_checkpoint_id_ = 0;
    int64_t start_index_;
  };

  const tstring shard_dir_;
  const std::string compression_;
  const int64_t version_;
  const DataTypeVector dtypes_;
  const std::vector<PartialTensorShape> shapes_;
  const int64_t start_index_;
};

Reader::DatasetOp::DatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kCompression, &compression_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kVersion, &version_));
}

void Reader::DatasetOp::MakeDataset(OpKernelContext* ctx,
                                    DatasetBase** output) {
  tstring shard_dir;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "shard_dir", &shard_dir));

  int64_t start_index;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "start_index", &start_index));

  *output =
      new Reader::Dataset(DatasetContext(ctx), shard_dir, compression_,
                          version_, output_types_, output_shapes_, start_index);
}

class Reader::NestedDataset : public DatasetBase {
 public:
  explicit NestedDataset(DatasetContext&& ctx,
                         std::vector<DatasetBase*> datasets)
      : DatasetBase(std::move(ctx)), datasets_(datasets) {
    dtypes_.push_back(DT_VARIANT);
    gtl::InlinedVector<int64_t, 1> element_dim_sizes;
    element_dim_sizes.push_back(1);
    partial_shapes_.emplace_back(element_dim_sizes);
  }

  const DataTypeVector& output_dtypes() const override { return dtypes_; }

  const std::vector<PartialTensorShape>& output_shapes() const override {
    return partial_shapes_;
  }

  std::string DebugString() const override {
    return "SnapshotNestedDatasetReader";
  }

  Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
    inputs->clear();
    return OkStatus();
  }

  Status CheckExternalState() const override { return OkStatus(); }

 protected:
  Status AsGraphDefInternal(SerializationContext* ctx,
                            DatasetGraphDefBuilder* b,
                            Node** node) const override {
    std::vector<Node*> input_graph_nodes;
    input_graph_nodes.reserve(datasets_.size());
    for (const auto& dataset : datasets_) {
      Node* input_node;
      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, dataset, &input_node));
      input_graph_nodes.emplace_back(input_node);
    }
    TF_RETURN_IF_ERROR(
        b->AddDataset(this, /*inputs=*/{},
                      /*list_inputs=*/{std::make_pair(0, input_graph_nodes)},
                      /*attrs=*/{}, node));
    return OkStatus();
  }

  std::unique_ptr<IteratorBase> MakeIteratorInternal(
      const string& prefix) const override {
    return std::make_unique<Iterator>(Iterator::Params{
        this, name_utils::IteratorPrefix(node_name(), prefix)});
  }

 private:
  std::vector<DatasetBase*> datasets_;
  DataTypeVector dtypes_;
  std::vector<PartialTensorShape> partial_shapes_;

  class Iterator : public DatasetIterator<NestedDataset> {
   public:
    explicit Iterator(const Params& params)
        : DatasetIterator<NestedDataset>(params) {}

   protected:
    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      const int64_t num_datasets = dataset()->datasets_.size();
      *end_of_sequence = num_datasets == index_;
      if (!*end_of_sequence) {
        Tensor tensor(DT_VARIANT, TensorShape({}));

        TF_RETURN_IF_ERROR(
            StoreDatasetInVariantTensor(dataset()->datasets_[index_], &tensor));
        out_tensors->clear();
        out_tensors->push_back(std::move(tensor));

        index_++;
      }
      return OkStatus();
    }

    Status SaveInternal(SerializationContext* ctx,
                        IteratorStateWriter* writer) override {
      TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
      return OkStatus();
    }

    Status RestoreInternal(IteratorContext* ctx,
                           IteratorStateReader* reader) override {
      TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_));
      return OkStatus();
    }

   private:
    int64_t index_ = 0;
  };
};

Reader::NestedDatasetOp::NestedDatasetOp(OpKernelConstruction* ctx)
    : DatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

void Reader::NestedDatasetOp::MakeDataset(OpKernelContext* ctx,
                                          DatasetBase** output) {
  std::vector<DatasetBase*> inputs;
  for (size_t i = 0; i < ctx->num_inputs(); ++i) {
    DatasetBase* input;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
    inputs.push_back(input);
  }
  *output = new Reader::NestedDataset(DatasetContext(ctx), inputs);
  (*output)->Initialize(/*metadata=*/{});
}

Status Reader::MakeNestedDataset(Env* env,
                                 const std::vector<std::string>& shard_dirs,
                                 const string& compression_type, int version,
                                 const DataTypeVector& dtypes,
                                 const std::vector<PartialTensorShape>& shapes,
                                 const int64_t start_index,
                                 DatasetBase** output) {
  std::vector<DatasetBase*> datasets;

  datasets.reserve(shard_dirs.size());
  for (int64_t i = 0; i < shard_dirs.size(); ++i) {
    // TODO(frankchn): The reading pattern could be controlled in a
    // non-round-robin fashion, so we cannot assume round-robin order when
    // restoring.
    int64_t dataset_start_index = start_index / shard_dirs.size();
    if (start_index % shard_dirs.size() > datasets.size()) {
      dataset_start_index++;
    }

    datasets.push_back(
        new Dataset(DatasetContext(DatasetContext::Params(
                        {"SnapshotDatasetReader",
                         strings::StrCat("SnapshotDatasetReader/_", i)})),
                    shard_dirs.at(i), compression_type, version, dtypes, shapes,
                    dataset_start_index));
    datasets.back()->Initialize(/*metadata=*/{});
  }

  // Rotate the vector such that the first dataset contains the next element
  // to be produced, but not if there are no shards at all (then we just
  // construct an empty dataset).
  if (!shard_dirs.empty()) {
    std::rotate(datasets.begin(),
                datasets.begin() + (start_index % shard_dirs.size()),
                datasets.end());
  }
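
  // Worked example (hypothetical values): with 3 shards and start_index = 7,
  // elements 0..6 were produced round-robin, so shard 0 has emitted 3 elements
  // and shards 1 and 2 have emitted 2 each. The loop above computes
  // dataset_start_index = 7 / 3 = 2 and bumps it to 3 for shard 0 only
  // (since 7 % 3 = 1 > 0), and the rotation by 7 % 3 = 1 places shard 1,
  // the producer of element 7, at the front.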

  *output = new NestedDataset(
      DatasetContext(DatasetContext::Params(
          {"SnapshotNestedDatasetReader", "SnapshotNestedDatasetReader"})),
      datasets);
  (*output)->Initialize(/*metadata=*/{});
  return OkStatus();
}

TFRecordReader::TFRecordReader(const std::string& filename,
                               const string& compression_type,
                               const DataTypeVector& dtypes)
    : filename_(filename),
      offset_(0),
      compression_type_(compression_type),
      dtypes_(dtypes) {}

Status TFRecordReader::Initialize(Env* env) {
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));

  record_reader_ = std::make_unique<io::RecordReader>(
      file_.get(), io::RecordReaderOptions::CreateRecordReaderOptions(
                       /*compression_type=*/compression_type_));
  return OkStatus();
}

Status TFRecordReader::ReadTensors(std::vector<Tensor>* read_tensors) {
  read_tensors->reserve(dtypes_.size());
  for (int i = 0; i < dtypes_.size(); ++i) {
    tstring record;
    TF_RETURN_IF_ERROR(record_reader_->ReadRecord(&offset_, &record));

    TensorProto proto;
    proto.ParseFromArray(record.data(), record.size());

    Tensor tensor;
    if (!tensor.FromProto(proto)) {
      return errors::DataLoss("Unable to parse tensor from stored proto.");
    }

    read_tensors->push_back(std::move(tensor));
  }
  return OkStatus();
}

CustomReader::CustomReader(const std::string& filename,
                           const string& compression_type, const int version,
                           const DataTypeVector& dtypes)
    : filename_(filename),
      compression_type_(compression_type),
      version_(version),
      dtypes_(dtypes) {}

Status CustomReader::Initialize(Env* env) {
  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_));
  input_stream_ = std::make_unique<io::RandomAccessInputStream>(file_.get());

#if defined(IS_SLIM_BUILD)
  if (compression_type_ != io::compression::kNone) {
    LOG(ERROR) << "Compression is unsupported on mobile platforms. Turning "
               << "off compression.";
  }
#else   // IS_SLIM_BUILD
  if (compression_type_ == io::compression::kGzip) {
    io::ZlibCompressionOptions zlib_options;
    zlib_options = io::ZlibCompressionOptions::GZIP();

    input_stream_ = std::make_unique<io::ZlibInputStream>(
        input_stream_.release(), zlib_options.input_buffer_size,
        zlib_options.output_buffer_size, zlib_options, true);
  } else if (compression_type_ == io::compression::kSnappy) {
    if (version_ == 0) {
      input_stream_ = std::make_unique<io::SnappyInputBuffer>(
          file_.get(), /*input_buffer_bytes=*/kSnappyReaderInputBufferSizeBytes,
          /*output_buffer_bytes=*/kSnappyReaderOutputBufferSizeBytes);
    } else {
      input_stream_ =
          std::make_unique<io::BufferedInputStream>(file_.get(), 64 << 20);
    }
  }
#endif  // IS_SLIM_BUILD
  simple_tensor_mask_.reserve(dtypes_.size());
  for (const auto& dtype : dtypes_) {
    if (DataTypeCanUseMemcpy(dtype)) {
      simple_tensor_mask_.push_back(true);
      num_simple_++;
    } else {
      simple_tensor_mask_.push_back(false);
      num_complex_++;
    }
  }

  return OkStatus();
}

Status CustomReader::ReadTensors(std::vector<Tensor>* read_tensors) {
  profiler::TraceMe activity(
      [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); },
      profiler::TraceMeLevel::kInfo);
  if (version_ == 0 || compression_type_ != io::compression::kSnappy) {
    return ReadTensorsV0(read_tensors);
  }
  if (version_ != 1) {
    return errors::InvalidArgument("Version: ", version_, " is not supported.");
  }
  if (compression_type_ != io::compression::kSnappy) {
    return errors::InvalidArgument("Compression ", compression_type_,
                                   " is not supported.");
  }

  experimental::SnapshotTensorMetadata metadata;
  tstring metadata_str;
  TF_RETURN_IF_ERROR(ReadRecord(&metadata_str));
  if (!metadata.ParseFromArray(metadata_str.data(), metadata_str.size())) {
    return errors::DataLoss("Could not parse SnapshotTensorMetadata");
  }
  read_tensors->reserve(metadata.tensor_metadata_size());

  std::vector<Tensor> simple_tensors;
  simple_tensors.reserve(num_simple_);
  std::vector<std::pair<std::unique_ptr<char[]>, size_t>> tensor_proto_strs;
  tensor_proto_strs.reserve(num_complex_);
  TF_RETURN_IF_ERROR(
      SnappyUncompress(&metadata, &simple_tensors, &tensor_proto_strs));

  int simple_index = 0;
  int complex_index = 0;
  for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
    if (simple_tensor_mask_[i]) {
      read_tensors->push_back(std::move(simple_tensors[simple_index]));
      simple_index++;
    } else {
      auto tensor_proto_str = std::move(tensor_proto_strs[complex_index].first);
      size_t tensor_proto_size = tensor_proto_strs[complex_index].second;
      TensorProto tp;
      if (!tp.ParseFromArray(tensor_proto_str.get(), tensor_proto_size)) {
        return errors::Internal("Could not parse TensorProto");
      }
      Tensor t;
      if (!t.FromProto(tp)) {
        return errors::Internal("Could not parse Tensor");
      }
      read_tensors->push_back(std::move(t));
      complex_index++;
    }
  }
  return OkStatus();
}

Status CustomReader::ReadTensorsV0(std::vector<Tensor>* read_tensors) {
  experimental::SnapshotRecord record;
#if defined(PLATFORM_GOOGLE)
  absl::Cord c;
  TF_RETURN_IF_ERROR(ReadRecord(&c));
  record.ParseFromCord(c);
#else   // PLATFORM_GOOGLE
  tstring record_bytes;
  TF_RETURN_IF_ERROR(ReadRecord(&record_bytes));
  record.ParseFromArray(record_bytes.data(), record_bytes.size());
#endif  // PLATFORM_GOOGLE
  read_tensors->reserve(record.tensor_size());
  for (int i = 0; i < record.tensor_size(); ++i) {
    read_tensors->emplace_back();
    if (!read_tensors->back().FromProto(record.tensor(i))) {
      return errors::DataLoss("Unable to parse tensor from proto.");
    }
  }
  return OkStatus();
}

Status CustomReader::SnappyUncompress(
    const experimental::SnapshotTensorMetadata* metadata,
    std::vector<Tensor>* simple_tensors,
    std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
        tensor_proto_strs) {
  tstring compressed;
  TF_RETURN_IF_ERROR(ReadRecord(&compressed));
  size_t size;
  if (!port::Snappy_GetUncompressedLength(compressed.data(), compressed.size(),
                                          &size)) {
    return errors::Internal("Could not get snappy uncompressed length");
  }

  int num_tensors = metadata->tensor_metadata_size();
  std::vector<struct iovec> iov(num_tensors);
  int index = 0;
  int64_t total_size = 0;
  for (int i = 0, end = simple_tensor_mask_.size(); i < end; ++i) {
    const auto& tensor_metadata = metadata->tensor_metadata(i);
    if (simple_tensor_mask_[i]) {
      TensorShape shape(tensor_metadata.tensor_shape());
      Tensor simple_tensor(dtypes_[i], shape);
      TensorBuffer* buffer = DMAHelper::buffer(&simple_tensor);
      iov[index].iov_base = buffer->data();
      iov[index].iov_len = buffer->size();
      simple_tensors->push_back(std::move(simple_tensor));
    } else {
      auto tensor_proto_str =
          std::make_unique<char[]>(tensor_metadata.tensor_size_bytes());
      iov[index].iov_base = tensor_proto_str.get();
      iov[index].iov_len = tensor_metadata.tensor_size_bytes();
      tensor_proto_strs->push_back(std::make_pair(
          std::move(tensor_proto_str), tensor_metadata.tensor_size_bytes()));
    }
    total_size += iov[index].iov_len;
    index++;
  }
  const int64_t size_int = size;
  if (size_int != total_size) {
    return errors::Internal("Uncompressed size mismatch. Snappy expects ", size,
                            " whereas the tensor metadata suggests ",
                            total_size);
  }
  if (!port::Snappy_UncompressToIOVec(compressed.data(), compressed.size(),
                                      iov.data(), num_tensors)) {
    return errors::Internal("Failed to perform snappy decompression.");
  }
  return OkStatus();
}

Status CustomReader::ReadRecord(tstring* record) {
  tstring header;
  TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
  uint64 length = core::DecodeFixed64(header.data());
  return input_stream_->ReadNBytes(length, record);
}

#if defined(TF_CORD_SUPPORT)
Status CustomReader::ReadRecord(absl::Cord* record) {
  tstring header;
  TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header));
  uint64 length = core::DecodeFixed64(header.data());
  if (compression_type_ == io::compression::kNone) {
    return input_stream_->ReadNBytes(length, record);
  } else {
    auto tmp_str = new tstring();
    TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(length, tmp_str));
    absl::string_view tmp_str_view(*tmp_str);
    record->Append(absl::MakeCordFromExternal(
        tmp_str_view, [tmp_str](absl::string_view) { delete tmp_str; }));
    return OkStatus();
  }
}
#endif  // TF_CORD_SUPPORT

Status WriteMetadataFile(Env* env, const string& dir,
                         const experimental::SnapshotMetadataRecord* metadata) {
  string metadata_filename = io::JoinPath(dir, kMetadataFilename);
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir));
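  // Write to a unique temporary file and rename it into place so that
  // concurrent readers never observe a partially written metadata file.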
  std::string tmp_filename =
      absl::StrCat(metadata_filename, "-tmp-", random::New64());
  TF_RETURN_IF_ERROR(WriteBinaryProto(env, tmp_filename, *metadata));
  return env->RenameFile(tmp_filename, metadata_filename);
}

Status ReadMetadataFile(Env* env, const string& dir,
                        experimental::SnapshotMetadataRecord* metadata,
                        bool* file_exists) {
  string metadata_filename = io::JoinPath(dir, kMetadataFilename);
  Status s = env->FileExists(metadata_filename);
  *file_exists = s.ok();

  if (*file_exists) {
    return ReadBinaryProto(env, metadata_filename, metadata);
  } else {
    return OkStatus();
  }
}

Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
                        const GraphDef* graph) {
  std::string hash_hex =
      strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
  std::string graph_file =
      io::JoinPath(path, absl::StrCat(hash_hex, "-graph.pbtxt"));

  LOG(INFO) << "Graph hash is " << hash_hex << ", writing to " << graph_file;
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(path));
  return WriteTextProto(env, graph_file, *graph);
}

Status DetermineOpState(const std::string& mode_string, bool file_exists,
                        const experimental::SnapshotMetadataRecord* metadata,
                        const uint64 pending_snapshot_expiry_seconds,
                        Mode* mode) {
  if (mode_string == kModeRead) {
    // In read mode, we expect a metadata file to have been written already.
    if (!file_exists) {
      return errors::NotFound("Metadata file does not exist.");
    }
    LOG(INFO) << "Overriding mode to reader.";
    *mode = READER;
    return OkStatus();
  }

  if (mode_string == kModeWrite) {
    LOG(INFO) << "Overriding mode to writer.";
    *mode = WRITER;
    return OkStatus();
  }

  if (mode_string == kModePassthrough) {
    LOG(INFO) << "Overriding mode to passthrough.";
    *mode = PASSTHROUGH;
    return OkStatus();
  }

  if (!file_exists) {
    *mode = WRITER;
    return OkStatus();
  }

  if (metadata->finalized()) {
    // File found, and the snapshot has been finalized.
    *mode = READER;
    return OkStatus();
  }

  int64_t expiration_timer = static_cast<int64_t>(EnvTime::NowMicros()) -
                             pending_snapshot_expiry_seconds * 1000000;

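  // Example (hypothetical value): with pending_snapshot_expiry_seconds =
  // 86400, an unfinalized snapshot whose creation_timestamp (microseconds) is
  // older than NowMicros() - 86400 * 1000000 is treated as abandoned below.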
  if (metadata->creation_timestamp() >= expiration_timer) {
    // Someone else is already writing and time has not expired.
    *mode = PASSTHROUGH;
    return OkStatus();
  } else {
    // Time has expired, so we write regardless.
    *mode = WRITER;
    return OkStatus();
  }
}

AsyncWriter::AsyncWriter(Env* env, int64_t file_index,
                         const std::string& shard_directory,
                         uint64 checkpoint_id, const std::string& compression,
                         int64_t version, const DataTypeVector& output_types,
                         std::function<void(Status)> done) {
  // Capture output_types by value: the writer thread can outlive this
  // constructor, so a captured reference to the parameter would dangle.
  thread_ = absl::WrapUnique(env->StartThread(
      ThreadOptions(), absl::StrCat("writer_thread_", file_index),
      [this, env, shard_directory, checkpoint_id, compression, version,
       output_types, done = std::move(done)] {
        done(WriterThread(env, shard_directory, checkpoint_id, compression,
                          version, output_types));
      }));
}

void AsyncWriter::Write(const std::vector<Tensor>& tensors) {
  mutex_lock l(mu_);
  ElementOrEOF element;
  element.value = tensors;
  deque_.push_back(std::move(element));
}

void AsyncWriter::SignalEOF() {
  mutex_lock l(mu_);
  ElementOrEOF be;
  be.end_of_sequence = true;
  deque_.push_back(std::move(be));
}

void AsyncWriter::Consume(ElementOrEOF* be) {
  mutex_lock l(mu_);
  mu_.Await(tensorflow::Condition(this, &AsyncWriter::ElementAvailable));
  *be = deque_.front();
  deque_.pop_front();
}

bool AsyncWriter::ElementAvailable() { return !deque_.empty(); }

Status AsyncWriter::WriterThread(Env* env, const std::string& shard_directory,
                                 uint64 checkpoint_id,
                                 const std::string& compression,
                                 int64_t version, DataTypeVector output_types) {
  std::unique_ptr<snapshot_util::Writer> writer;
  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory));

  TF_RETURN_IF_ERROR(snapshot_util::Writer::Create(
      env, GetCheckpointFileName(shard_directory, checkpoint_id), compression,
      version, std::move(output_types), &writer));

  while (true) {
    ElementOrEOF be;
    Consume(&be);

    if (be.end_of_sequence) {
      TF_RETURN_IF_ERROR(writer->Close());
      break;
    }

    TF_RETURN_IF_ERROR(writer->WriteTensors(be.value));
  }
  return OkStatus();
}
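
// Minimal usage sketch (hypothetical path; the done callback runs on the
// writer thread, and the AsyncWriter must outlive SignalEOF processing):
//   AsyncWriter writer(Env::Default(), /*file_index=*/0,
//                      "/tmp/run/00000000.shard", /*checkpoint_id=*/0,
//                      io::compression::kSnappy, /*version=*/2, {DT_INT64},
//                      [](Status s) { TF_CHECK_OK(s); });
//   Tensor t(DT_INT64, TensorShape({}));
//   t.scalar<int64_t>()() = 42;
//   writer.Write({t});
//   writer.SignalEOF();  // Flushes remaining elements and closes the file.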

namespace {

REGISTER_KERNEL_BUILDER(Name("SnapshotDatasetReader").Device(DEVICE_CPU),
                        Reader::DatasetOp);
REGISTER_KERNEL_BUILDER(Name("SnapshotNestedDatasetReader").Device(DEVICE_CPU),
                        Reader::NestedDatasetOp);

}  // namespace
}  // namespace snapshot_util
}  // namespace data
}  // namespace tensorflow