1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_
17#define TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_
18
#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/io/compression.h"
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/lib/io/record_reader.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/status.h"
31
32namespace tensorflow {
33
34class GraphDef;
35
36namespace data {
37
38namespace experimental {
39
40class SnapshotMetadataRecord;
41class SnapshotTensorMetadata;
42
43} // namespace experimental
44
45namespace snapshot_util {
46
// Name of the snapshot metadata file.
constexpr char kMetadataFilename[] = "snapshot.metadata";

// String spellings of the snapshot modes.
constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read";
constexpr char kModePassthrough[] = "passthrough";

// Suffix used for shard directory names.
constexpr char kShardDirectorySuffix[] = ".shard";
54
// Operating mode of a snapshot op.
enum Mode {
  READER = 0,
  WRITER = 1,
  PASSTHROUGH = 2,
};
56
57// Returns the name of the "hash" directory for the given base path and hash ID.
58std::string HashDirectory(const std::string& path, uint64 hash);
59
60// Returns the name of the "run" directory for the given base path and run ID.
61std::string RunDirectory(const std::string& hash_directory, uint64 run_id);
62std::string RunDirectory(const std::string& hash_directory,
63 const std::string& run_id);
64
65// Returns the name of the "shard" directory for the given base path and shard
66// ID.
67std::string ShardDirectory(const std::string& run_directory, int64_t shard_id);
68
69// Returns the checkpoint file name for the given directory and checkpoint ID.
70std::string GetCheckpointFileName(const std::string& shard_directory,
71 const uint64 checkpoint_id);
72
73// This is a interface class that exposes snapshot writing functionality.
74class Writer {
75 public:
76 // Creates a new writer object.
77 static Status Create(Env* env, const std::string& filename,
78 const std::string& compression_type, int version,
79 const DataTypeVector& dtypes,
80 std::unique_ptr<Writer>* out_writer);
81
82 // Writes a vector of tensors to the snapshot writer file.
83 virtual Status WriteTensors(const std::vector<Tensor>& tensors) = 0;
84
85 // Flushes any in-memory buffers to disk.
86 virtual Status Sync() = 0;
87
88 // Closes and finalizes the snapshot file. All calls to any other method will
89 // be invalid after this call.
90 virtual Status Close() = 0;
91
92 virtual ~Writer() {}
93
94 protected:
95 virtual Status Initialize(tensorflow::Env* env) = 0;
96};
97
98// Writes snapshots with the standard TFRecord file format.
99class TFRecordWriter : public Writer {
100 public:
101 TFRecordWriter(const std::string& filename,
102 const std::string& compression_type);
103
104 Status WriteTensors(const std::vector<Tensor>& tensors) override;
105
106 Status Sync() override;
107
108 Status Close() override;
109
110 ~TFRecordWriter() override;
111
112 protected:
113 Status Initialize(tensorflow::Env* env) override;
114
115 private:
116 const std::string filename_;
117 const std::string compression_type_;
118
119 std::unique_ptr<WritableFile> dest_;
120 std::unique_ptr<io::RecordWriter> record_writer_;
121};
122
123// Writes snapshot with a custom (legacy) file format.
124class CustomWriter : public Writer {
125 public:
126 static constexpr const size_t kHeaderSize = sizeof(uint64);
127
128 static constexpr const char* const kClassName = "SnapshotWriter";
129 static constexpr const char* const kWriteStringPiece = "WriteStringPiece";
130 static constexpr const char* const kWriteCord = "WriteCord";
131 static constexpr const char* const kSeparator = "::";
132
133 CustomWriter(const std::string& filename, const std::string& compression_type,
134 const DataTypeVector& dtypes);
135
136 Status WriteTensors(const std::vector<Tensor>& tensors) override;
137
138 Status Sync() override;
139
140 Status Close() override;
141
142 ~CustomWriter() override;
143
144 protected:
145 Status Initialize(tensorflow::Env* env) override;
146
147 private:
148 Status WriteRecord(const StringPiece& data);
149
150#if defined(TF_CORD_SUPPORT)
151 Status WriteRecord(const absl::Cord& data);
152#endif // TF_CORD_SUPPORT
153
154 std::unique_ptr<WritableFile> dest_;
155 const std::string filename_;
156 const std::string compression_type_;
157 const DataTypeVector dtypes_;
158 // We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that
159 // in dest_ if we want compression. ZlibOutputBuffer doesn't own the original
160 // dest_ and so we need somewhere to store the original one.
161 std::unique_ptr<WritableFile> zlib_underlying_dest_;
162 std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
163 int num_simple_ = 0;
164 int num_complex_ = 0;
165};
166
167// Interface class for reading snapshot files previous written with Writer.
168class Reader {
169 public:
170 // Op kernel that creates an instance of `Reader::Dataset` needed to support
171 // serialization and deserialization of `Reader::Dataset`.
172 class DatasetOp : public DatasetOpKernel {
173 public:
174 explicit DatasetOp(OpKernelConstruction* ctx);
175
176 protected:
177 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
178
179 private:
180 DataTypeVector output_types_;
181 std::vector<PartialTensorShape> output_shapes_;
182 std::string compression_;
183 int64_t version_;
184 };
185
186 // Op kernel that creates an instance of `Reader::NestedDataset` needed to
187 // support serialization and deserialization of `Reader::NestedDataset`.
188 class NestedDatasetOp : public DatasetOpKernel {
189 public:
190 explicit NestedDatasetOp(OpKernelConstruction* ctx);
191
192 protected:
193 void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
194
195 private:
196 DataTypeVector output_types_;
197 std::vector<PartialTensorShape> output_shapes_;
198 };
199
200 // Creates a new Reader object that reads data from `filename`. Note that
201 // the `version`, `compression_type`, and `dtypes` arguments passed into
202 // `Writer` and `Reader` must be the same for the reading to succeed.
203 static Status Create(Env* env, const std::string& filename,
204 const string& compression_type, int version,
205 const DataTypeVector& dtypes,
206 std::unique_ptr<Reader>* out_reader);
207
208 // Returns a nested dataset for a set of given snapshot file names.
209 //
210 // This function takes a vector of snapshot files, and returns a nested
211 // dataset. Each element within the nested dataset is itself a dataset, and
212 // contains all the elements written out to each individual snapshot file.
213 static Status MakeNestedDataset(Env* env,
214 const std::vector<std::string>& shard_dirs,
215 const string& compression_type, int version,
216 const DataTypeVector& dtypes,
217 const std::vector<PartialTensorShape>& shapes,
218 const int64_t start_index,
219 DatasetBase** output);
220
221 // Reads a vector of Tensors from the snapshot file.
222 virtual Status ReadTensors(std::vector<Tensor>* read_tensors) = 0;
223
224 // Skips `num_records`. Equivalent to calling `ReadTensors` `num_records`
225 // times then discarding the results.
226 virtual Status SkipRecords(int64_t num_records);
227
228 virtual ~Reader() {}
229
230 protected:
231 virtual Status Initialize(Env* env) = 0;
232
233 class Dataset;
234 class NestedDataset;
235};
236
237// Reads snapshots previously written with `TFRecordWriter`.
238class TFRecordReader : public Reader {
239 public:
240 TFRecordReader(const std::string& filename, const string& compression_type,
241 const DataTypeVector& dtypes);
242
243 Status ReadTensors(std::vector<Tensor>* read_tensors) override;
244
245 ~TFRecordReader() override {}
246
247 protected:
248 Status Initialize(Env* env) override;
249
250 private:
251 std::string filename_;
252 std::unique_ptr<RandomAccessFile> file_;
253 std::unique_ptr<io::RecordReader> record_reader_;
254 uint64 offset_;
255
256 const string compression_type_;
257 const DataTypeVector dtypes_;
258};
259
260// Reads snapshots previously written with `CustomWriter`.
261class CustomReader : public Reader {
262 public:
263 // The reader input buffer size is deliberately large because the input reader
264 // will throw an error if the compressed block length cannot fit in the input
265 // buffer.
266 static constexpr const int64_t kSnappyReaderInputBufferSizeBytes =
267 1 << 30; // 1 GiB
268 // TODO(b/148804377): Set this in a smarter fashion.
269 static constexpr const int64_t kSnappyReaderOutputBufferSizeBytes =
270 32 << 20; // 32 MiB
271 static constexpr const size_t kHeaderSize = sizeof(uint64);
272
273 static constexpr const char* const kClassName = "SnapshotReader";
274 static constexpr const char* const kReadString = "ReadString";
275 static constexpr const char* const kReadCord = "ReadCord";
276 static constexpr const char* const kSeparator = "::";
277
278 CustomReader(const std::string& filename, const string& compression_type,
279 const int version, const DataTypeVector& dtypes);
280
281 Status ReadTensors(std::vector<Tensor>* read_tensors) override;
282
283 ~CustomReader() override {}
284
285 protected:
286 Status Initialize(Env* env) override;
287
288 private:
289 Status ReadTensorsV0(std::vector<Tensor>* read_tensors);
290
291 Status SnappyUncompress(
292 const experimental::SnapshotTensorMetadata* metadata,
293 std::vector<Tensor>* simple_tensors,
294 std::vector<std::pair<std::unique_ptr<char[]>, size_t>>*
295 tensor_proto_strs);
296
297 Status ReadRecord(tstring* record);
298
299#if defined(TF_CORD_SUPPORT)
300 Status ReadRecord(absl::Cord* record);
301#endif
302
303 std::string filename_;
304 std::unique_ptr<RandomAccessFile> file_;
305 std::unique_ptr<io::InputStreamInterface> input_stream_;
306 const string compression_type_;
307 const int version_;
308 const DataTypeVector dtypes_;
309 int num_simple_ = 0;
310 int num_complex_ = 0;
311 std::vector<bool> simple_tensor_mask_; // true for simple, false for complex.
312};
313
314// Writes snapshot metadata to the given directory.
315Status WriteMetadataFile(Env* env, const string& dir,
316 const experimental::SnapshotMetadataRecord* metadata);
317
318// Reads snapshot metadata from the given directory.
319Status ReadMetadataFile(Env* env, const string& dir,
320 experimental::SnapshotMetadataRecord* metadata,
321 bool* file_exists);
322
323// Writes a dataset graph to the given directory.
324Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
325 const GraphDef* graph);
326
327Status DetermineOpState(const std::string& mode_string, bool file_exists,
328 const experimental::SnapshotMetadataRecord* metadata,
329 const uint64 pending_snapshot_expiry_seconds,
330 Mode* mode);
331
332// Represents a dataset element or EOF.
333struct ElementOrEOF {
334 std::vector<Tensor> value;
335 bool end_of_sequence = false;
336};
337
338// AsyncWriter provides API for asynchronously writing dataset elements
339// (each represented as a vector of tensors) to a file.
340//
341// The expected use of this API is:
342//
343// std::unique_ptr<AsyncWriter> writer = absl_make_unique<AsyncWriter>(...);
344//
345// while (data_available()) {
346// std::vector<Tensor> data = read_data()
347// writer->Write(data);
348// }
349// writer->SignalEOF();
350// writer = nullptr; // This will block until writes are flushed.
351class AsyncWriter {
352 public:
353 explicit AsyncWriter(Env* env, int64_t file_index,
354 const std::string& shard_directory, uint64 checkpoint_id,
355 const std::string& compression, int64_t version,
356 const DataTypeVector& output_types,
357 std::function<void(Status)> done);
358
359 // Writes the given tensors. The method is non-blocking and returns without
360 // waiting for the element to be written.
361 void Write(const std::vector<Tensor>& tensors) TF_LOCKS_EXCLUDED(mu_);
362
363 // Signals the end of input. The method is non-blocking and returns without
364 // waiting for the writer to be closed.
365 void SignalEOF() TF_LOCKS_EXCLUDED(mu_);
366
367 private:
368 void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_);
369 bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
370 Status WriterThread(Env* env, const std::string& shard_directory,
371 uint64 checkpoint_id, const std::string& compression,
372 int64_t version, DataTypeVector output_types);
373
374 mutex mu_;
375 std::deque<ElementOrEOF> deque_ TF_GUARDED_BY(mu_);
376
377 // This has to be last. During destruction, we need to make sure that the
378 // Thread object is destroyed first as its destructor blocks on thread
379 // completion. If there are other member variables after this, they may get
380 // destroyed first before the thread finishes, potentially causing the
381 // thread to access invalid memory.
382 std::unique_ptr<Thread> thread_;
383};
384
385} // namespace snapshot_util
386} // namespace data
387} // namespace tensorflow
388
389#endif // TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_
390