1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_ |
17 | #define TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_ |
18 | |
19 | #include "tensorflow/core/framework/dataset.h" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/framework/types.h" |
23 | #include "tensorflow/core/lib/io/compression.h" |
24 | #include "tensorflow/core/lib/io/inputstream_interface.h" |
25 | #include "tensorflow/core/lib/io/record_reader.h" |
26 | #include "tensorflow/core/lib/io/record_writer.h" |
27 | #include "tensorflow/core/platform/env.h" |
28 | #include "tensorflow/core/platform/file_system.h" |
29 | #include "tensorflow/core/platform/path.h" |
30 | #include "tensorflow/core/platform/status.h" |
31 | |
32 | namespace tensorflow { |
33 | |
34 | class GraphDef; |
35 | |
36 | namespace data { |
37 | |
namespace experimental {

// Forward declarations of the snapshot metadata types used below; the full
// definitions are not needed in this header. NOTE(review): presumably
// generated proto messages — confirm against the snapshot proto definitions.
class SnapshotMetadataRecord;
class SnapshotTensorMetadata;

}  // namespace experimental
44 | |
45 | namespace snapshot_util { |
46 | |
// File name of the snapshot metadata file within a snapshot directory.
constexpr char kMetadataFilename[] = "snapshot.metadata";

// Accepted values for the snapshot op's `mode` attribute.
constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read";
constexpr char kModePassthrough[] = "passthrough";
// Suffix appended to each per-shard directory name.
constexpr char kShardDirectorySuffix[] = ".shard";

// Operational state of a snapshot op: read an existing snapshot, write a new
// one, or pass elements through without snapshotting.
enum Mode { READER = 0, WRITER = 1, PASSTHROUGH = 2 };
56 | |
// Returns the name of the "hash" directory for the given base path and hash ID.
std::string HashDirectory(const std::string& path, uint64 hash);

// Returns the name of the "run" directory for the given base path and run ID.
std::string RunDirectory(const std::string& hash_directory, uint64 run_id);
// Overload accepting an already-formatted run ID string.
std::string RunDirectory(const std::string& hash_directory,
                         const std::string& run_id);

// Returns the name of the "shard" directory for the given base path and shard
// ID.
std::string ShardDirectory(const std::string& run_directory, int64_t shard_id);

// Returns the checkpoint file name for the given directory and checkpoint ID.
std::string GetCheckpointFileName(const std::string& shard_directory,
                                  const uint64 checkpoint_id);
72 | |
// This is an interface class that exposes snapshot writing functionality.
class Writer {
 public:
  // Creates a new writer object for `filename`. The `version`,
  // `compression_type`, and `dtypes` must match what is later passed to
  // `Reader::Create` for the file to be readable.
  static Status Create(Env* env, const std::string& filename,
                       const std::string& compression_type, int version,
                       const DataTypeVector& dtypes,
                       std::unique_ptr<Writer>* out_writer);

  // Writes a vector of tensors to the snapshot writer file.
  virtual Status WriteTensors(const std::vector<Tensor>& tensors) = 0;

  // Flushes any in-memory buffers to disk.
  virtual Status Sync() = 0;

  // Closes and finalizes the snapshot file. All calls to any other method will
  // be invalid after this call.
  virtual Status Close() = 0;

  virtual ~Writer() {}

 protected:
  // One-time setup hook invoked before the writer is handed to the caller
  // (presumably by Create — confirm in the .cc file).
  virtual Status Initialize(tensorflow::Env* env) = 0;
};
97 | |
// Writes snapshots with the standard TFRecord file format.
class TFRecordWriter : public Writer {
 public:
  TFRecordWriter(const std::string& filename,
                 const std::string& compression_type);

  // Writes a vector of tensors to the snapshot file.
  Status WriteTensors(const std::vector<Tensor>& tensors) override;

  // Flushes any in-memory buffers to disk.
  Status Sync() override;

  // Closes and finalizes the snapshot file.
  Status Close() override;

  ~TFRecordWriter() override;

 protected:
  // One-time setup (presumably opens `filename_` and constructs
  // `record_writer_` — confirm in the .cc file).
  Status Initialize(tensorflow::Env* env) override;

 private:
  const std::string filename_;
  const std::string compression_type_;

  std::unique_ptr<WritableFile> dest_;  // Owned destination file.
  std::unique_ptr<io::RecordWriter> record_writer_;  // Writes into dest_.
};
122 | |
123 | // Writes snapshot with a custom (legacy) file format. |
124 | class CustomWriter : public Writer { |
125 | public: |
126 | static constexpr const size_t = sizeof(uint64); |
127 | |
128 | static constexpr const char* const kClassName = "SnapshotWriter" ; |
129 | static constexpr const char* const kWriteStringPiece = "WriteStringPiece" ; |
130 | static constexpr const char* const kWriteCord = "WriteCord" ; |
131 | static constexpr const char* const kSeparator = "::" ; |
132 | |
133 | CustomWriter(const std::string& filename, const std::string& compression_type, |
134 | const DataTypeVector& dtypes); |
135 | |
136 | Status WriteTensors(const std::vector<Tensor>& tensors) override; |
137 | |
138 | Status Sync() override; |
139 | |
140 | Status Close() override; |
141 | |
142 | ~CustomWriter() override; |
143 | |
144 | protected: |
145 | Status Initialize(tensorflow::Env* env) override; |
146 | |
147 | private: |
148 | Status WriteRecord(const StringPiece& data); |
149 | |
150 | #if defined(TF_CORD_SUPPORT) |
151 | Status WriteRecord(const absl::Cord& data); |
152 | #endif // TF_CORD_SUPPORT |
153 | |
154 | std::unique_ptr<WritableFile> dest_; |
155 | const std::string filename_; |
156 | const std::string compression_type_; |
157 | const DataTypeVector dtypes_; |
158 | // We hold zlib_dest_ because we may create a ZlibOutputBuffer and put that |
159 | // in dest_ if we want compression. ZlibOutputBuffer doesn't own the original |
160 | // dest_ and so we need somewhere to store the original one. |
161 | std::unique_ptr<WritableFile> zlib_underlying_dest_; |
162 | std::vector<bool> simple_tensor_mask_; // true for simple, false for complex. |
163 | int num_simple_ = 0; |
164 | int num_complex_ = 0; |
165 | }; |
166 | |
// Interface class for reading snapshot files previously written with Writer.
class Reader {
 public:
  // Op kernel that creates an instance of `Reader::Dataset` needed to support
  // serialization and deserialization of `Reader::Dataset`.
  class DatasetOp : public DatasetOpKernel {
   public:
    explicit DatasetOp(OpKernelConstruction* ctx);

   protected:
    void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;

   private:
    DataTypeVector output_types_;
    std::vector<PartialTensorShape> output_shapes_;
    std::string compression_;
    int64_t version_;
  };

  // Op kernel that creates an instance of `Reader::NestedDataset` needed to
  // support serialization and deserialization of `Reader::NestedDataset`.
  class NestedDatasetOp : public DatasetOpKernel {
   public:
    explicit NestedDatasetOp(OpKernelConstruction* ctx);

   protected:
    void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;

   private:
    DataTypeVector output_types_;
    std::vector<PartialTensorShape> output_shapes_;
  };

  // Creates a new Reader object that reads data from `filename`. Note that
  // the `version`, `compression_type`, and `dtypes` arguments passed into
  // `Writer` and `Reader` must be the same for the reading to succeed.
  static Status Create(Env* env, const std::string& filename,
                       const string& compression_type, int version,
                       const DataTypeVector& dtypes,
                       std::unique_ptr<Reader>* out_reader);

  // Returns a nested dataset for a set of given snapshot file names.
  //
  // This function takes a vector of snapshot files, and returns a nested
  // dataset. Each element within the nested dataset is itself a dataset, and
  // contains all the elements written out to each individual snapshot file.
  static Status MakeNestedDataset(Env* env,
                                  const std::vector<std::string>& shard_dirs,
                                  const string& compression_type, int version,
                                  const DataTypeVector& dtypes,
                                  const std::vector<PartialTensorShape>& shapes,
                                  const int64_t start_index,
                                  DatasetBase** output);

  // Reads a vector of Tensors from the snapshot file.
  virtual Status ReadTensors(std::vector<Tensor>* read_tensors) = 0;

  // Skips `num_records`. Equivalent to calling `ReadTensors` `num_records`
  // times then discarding the results.
  virtual Status SkipRecords(int64_t num_records);

  virtual ~Reader() {}

 protected:
  // One-time setup hook (presumably invoked by Create — confirm in the .cc).
  virtual Status Initialize(Env* env) = 0;

  // Dataset implementations backing DatasetOp/NestedDatasetOp; defined in the
  // implementation file.
  class Dataset;
  class NestedDataset;
};
236 | |
// Reads snapshots previously written with `TFRecordReader`.
class TFRecordReader : public Reader {
 public:
  TFRecordReader(const std::string& filename, const string& compression_type,
                 const DataTypeVector& dtypes);

  // Reads a vector of Tensors from the snapshot file.
  Status ReadTensors(std::vector<Tensor>* read_tensors) override;

  ~TFRecordReader() override {}

 protected:
  // One-time setup (presumably opens `filename_` and constructs
  // `record_reader_` — confirm in the .cc file).
  Status Initialize(Env* env) override;

 private:
  std::string filename_;
  std::unique_ptr<RandomAccessFile> file_;    // Owned input file.
  std::unique_ptr<io::RecordReader> record_reader_;
  uint64 offset_;  // Current read offset within the file.

  const string compression_type_;
  const DataTypeVector dtypes_;
};
259 | |
260 | // Reads snapshots previously written with `CustomWriter`. |
261 | class CustomReader : public Reader { |
262 | public: |
263 | // The reader input buffer size is deliberately large because the input reader |
264 | // will throw an error if the compressed block length cannot fit in the input |
265 | // buffer. |
266 | static constexpr const int64_t kSnappyReaderInputBufferSizeBytes = |
267 | 1 << 30; // 1 GiB |
268 | // TODO(b/148804377): Set this in a smarter fashion. |
269 | static constexpr const int64_t kSnappyReaderOutputBufferSizeBytes = |
270 | 32 << 20; // 32 MiB |
271 | static constexpr const size_t = sizeof(uint64); |
272 | |
273 | static constexpr const char* const kClassName = "SnapshotReader" ; |
274 | static constexpr const char* const kReadString = "ReadString" ; |
275 | static constexpr const char* const kReadCord = "ReadCord" ; |
276 | static constexpr const char* const kSeparator = "::" ; |
277 | |
278 | CustomReader(const std::string& filename, const string& compression_type, |
279 | const int version, const DataTypeVector& dtypes); |
280 | |
281 | Status ReadTensors(std::vector<Tensor>* read_tensors) override; |
282 | |
283 | ~CustomReader() override {} |
284 | |
285 | protected: |
286 | Status Initialize(Env* env) override; |
287 | |
288 | private: |
289 | Status ReadTensorsV0(std::vector<Tensor>* read_tensors); |
290 | |
291 | Status SnappyUncompress( |
292 | const experimental::SnapshotTensorMetadata* metadata, |
293 | std::vector<Tensor>* simple_tensors, |
294 | std::vector<std::pair<std::unique_ptr<char[]>, size_t>>* |
295 | tensor_proto_strs); |
296 | |
297 | Status ReadRecord(tstring* record); |
298 | |
299 | #if defined(TF_CORD_SUPPORT) |
300 | Status ReadRecord(absl::Cord* record); |
301 | #endif |
302 | |
303 | std::string filename_; |
304 | std::unique_ptr<RandomAccessFile> file_; |
305 | std::unique_ptr<io::InputStreamInterface> input_stream_; |
306 | const string compression_type_; |
307 | const int version_; |
308 | const DataTypeVector dtypes_; |
309 | int num_simple_ = 0; |
310 | int num_complex_ = 0; |
311 | std::vector<bool> simple_tensor_mask_; // true for simple, false for complex. |
312 | }; |
313 | |
// Writes snapshot metadata to the given directory.
Status WriteMetadataFile(Env* env, const string& dir,
                         const experimental::SnapshotMetadataRecord* metadata);

// Reads snapshot metadata from the given directory. `*file_exists` is
// presumably set to whether a metadata file was found — confirm in the .cc.
Status ReadMetadataFile(Env* env, const string& dir,
                        experimental::SnapshotMetadataRecord* metadata,
                        bool* file_exists);

// Writes a dataset graph to the given directory.
Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash,
                        const GraphDef* graph);

// Determines the operational state (read/write/passthrough, see `Mode`) for a
// snapshot op, given the requested `mode_string`, whether a metadata file
// exists, and the existing metadata. NOTE(review): the exact precedence rules
// live in the implementation — confirm there.
Status DetermineOpState(const std::string& mode_string, bool file_exists,
                        const experimental::SnapshotMetadataRecord* metadata,
                        const uint64 pending_snapshot_expiry_seconds,
                        Mode* mode);
331 | |
// Represents a dataset element or EOF.
struct ElementOrEOF {
  // The tensors making up one dataset element.
  std::vector<Tensor> value;
  // True when this entry marks the end of the input sequence rather than a
  // real element.
  bool end_of_sequence = false;
};
337 | |
// AsyncWriter provides API for asynchronously writing dataset elements
// (each represented as a vector of tensors) to a file.
//
// The expected use of this API is:
//
// std::unique_ptr<AsyncWriter> writer = absl::make_unique<AsyncWriter>(...);
//
// while (data_available()) {
//   std::vector<Tensor> data = read_data()
//   writer->Write(data);
// }
// writer->SignalEOF();
// writer = nullptr;  // This will block until writes are flushed.
class AsyncWriter {
 public:
  // Creates a writer for a checkpoint file under `shard_directory`. `done` is
  // presumably invoked with the writer thread's final status — confirm in the
  // .cc file.
  explicit AsyncWriter(Env* env, int64_t file_index,
                       const std::string& shard_directory, uint64 checkpoint_id,
                       const std::string& compression, int64_t version,
                       const DataTypeVector& output_types,
                       std::function<void(Status)> done);

  // Writes the given tensors. The method is non-blocking and returns without
  // waiting for the element to be written.
  void Write(const std::vector<Tensor>& tensors) TF_LOCKS_EXCLUDED(mu_);

  // Signals the end of input. The method is non-blocking and returns without
  // waiting for the writer to be closed.
  void SignalEOF() TF_LOCKS_EXCLUDED(mu_);

 private:
  // Enqueues `be` for consumption by the writer thread.
  void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_);
  // Predicate for the writer thread: whether an element is queued.
  bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
  // Body of the background thread that drains `deque_` to disk.
  Status WriterThread(Env* env, const std::string& shard_directory,
                      uint64 checkpoint_id, const std::string& compression,
                      int64_t version, DataTypeVector output_types);

  mutex mu_;
  // Queue of pending elements, produced by Write/SignalEOF and consumed by the
  // writer thread.
  std::deque<ElementOrEOF> deque_ TF_GUARDED_BY(mu_);

  // This has to be last. During destruction, we need to make sure that the
  // Thread object is destroyed first as its destructor blocks on thread
  // completion. If there are other member variables after this, they may get
  // destroyed first before the thread finishes, potentially causing the
  // thread to access invalid memory.
  std::unique_ptr<Thread> thread_;
};
384 | |
385 | } // namespace snapshot_util |
386 | } // namespace data |
387 | } // namespace tensorflow |
388 | |
389 | #endif // TENSORFLOW_CORE_DATA_SNAPSHOT_UTILS_H_ |
390 | |