1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // A tensor bundle is a set of immutable persistent files storing a set of named |
17 | // tensors. It is designed for checkpointing TensorFlow tensors. |
18 | // |
19 | // The paths of the managed files share a common prefix; e.g., with the prefix: |
20 | // /fs/model/train/ckpt-step/ckpt |
21 | // |
22 | // the bundle may contain a metadata file, and sharded data files: |
23 | // /fs/model/train/ckpt-step/ |
24 | // ckpt.index |
25 | // ckpt.data-00000-of-00020 |
26 | // ckpt.data-00001-of-00020 |
27 | // ... |
28 | // ckpt.data-00019-of-00020 |
29 | // |
30 | // The ".index" file is a string-string immutable table |
31 | // (tensorflow::table::Table). Each key is a name of a tensor and its value is |
32 | // a serialized BundleEntryProto. Each BundleEntryProto describes the metadata |
33 | // of a tensor: which of the "data" files contains the content of a tensor, the |
34 | // offset into that file, checksum, some auxiliary data, etc. |
35 | // |
36 | // A tensor bundle can be accessed randomly using a BundleReader. Usage: |
37 | // |
38 | // BundleReader reader(env, "/fs/model/train/ckpt-step/ckpt"); |
39 | // reader.Lookup("name", &tensor); |
40 | // |
// A tensor bundle can be built using BundleWriter. Each BundleWriter builds a
// single data file bundle. Multiple bundles can then be merged by
// MergeBundles() without reading and writing large chunks of data: it reads
// the metadata files and outputs a single merged metadata file. Typical usage:
45 | // |
46 | // worker 0: |
47 | // BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step"); |
48 | // writer.Add(...); // Adds the tensors on this worker. |
49 | // writer.Finish(); // Flushes. |
50 | // worker 1: |
51 | // BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step"); |
52 | // writer.Add(...); |
53 | // writer.Finish(); |
54 | // worker 2: |
55 | // MergeBundles(env, |
56 | // {"/fs/model/train/ckpt-step/tmp/worker0-step", |
57 | // "/fs/model/train/ckpt-step/tmp/worker1-step"}, |
58 | // "/fs/model/train/ckpt-step/ckpt" /* merged prefix */); |
59 | // |
60 | |
61 | #ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ |
62 | #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ |
63 | |
#include <cstdint>
#include <map>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/functional/function_ref.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/io/cache.h"
#include "tensorflow/core/lib/io/inputbuffer.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensor_bundle.pb.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
#include "tensorflow/core/util/tensor_slice_set.h"
87 | |
88 | namespace tensorflow { |
89 | |
90 | class FileOutputBuffer; |
91 | |
// Versioning of the tensor bundle format.
// Follows the same rules as 3p/tf/core/public/version.h.
//
// History:
// 0. Any tensor bundles produced before this field was added.
// 1. Added this field (2016-09-14).
extern const int kTensorBundleMinProducer;
extern const int kTensorBundleMinConsumer;
extern const int kTensorBundleVersion;

// The empty string, hence always the first key in the metadata table. Its
// corresponding value is a BundleHeaderProto.
// NOTE: the identifier was missing in the declaration ("extern const char*
// const ;"), which does not compile; restored to the canonical name.
extern const char* const kHeaderEntryKey;
105 | |
// Builds a string-string table of tensor names to BundleEntryProto (metadata).
//
// On construction, attempts to create a directory given by the dirname of
// "prefix", so "status()" must be checked before calling any member functions.
//
// All threads accessing the same BundleWriter must synchronize.
class BundleWriter {
 public:
  struct Options {
    Options() {}
    // Alignment, in bytes, for tensor data.
    // Must be >= 1. The default size of 1 densely packs tensors.
    int data_alignment{1};
  };
  BundleWriter(Env* env, StringPiece prefix,
               const Options& options = Options());

  // Adds the tensor "val" under key "key".
  // Across calls "key" must be unique but can be added in any order.
  Status Add(StringPiece key, const Tensor& val);

  // Partitioned variables support.
  // A slice of a full tensor is stored in two entries in the metadata table:
  //
  //   full_tensor_key   -> BundleEntryProto, describing all stored slices
  //                        of this full tensor.  Does not append to the data
  //                        file.
  //   encoded slice key -> BundleEntryProto, describing one particular slice.
  //                        Appends values of this slice to the data file.
  //
  // Slices of a full tensor can be added in any order.
  //
  // If a full tensor has slices placed on N devices and N BundleWriter's are
  // concurrently used, the caller must use MergeBundles() to ensure that a
  // consistent entry for "full_tensor_key" is produced.
  //
  // Returns an error if the same slice is added the second time.
  Status AddSlice(StringPiece full_tensor_key,
                  const TensorShape& full_tensor_shape,
                  const TensorSlice& slice_spec, const Tensor& slice_tensor);

  // Finishes the writer and flushes.
  Status Finish() TF_MUST_USE_RESULT;

  // First error encountered (e.g. during construction or an Add* call), or
  // OK.  Once non-OK, subsequent operations will not succeed.
  Status status() const { return status_; }

 private:
  Env* const env_;  // Not owned.
  const Options options_;
  const string prefix_;   // Caller-supplied path prefix of this bundle.
  string metadata_path_;  // Path of the metadata (".index") file.
  string data_path_;      // Path of the data file this writer appends to.
  // Whether output is staged in a temporary file (presumably finalized by
  // Finish() — confirm in the .cc file).
  bool use_temp_file_;
  std::unique_ptr<FileOutputBuffer> out_;  // Buffered writer for data_path_.
  int64_t size_;  // Number of bytes written into out_.
  // Tensor key -> metadata entry, kept ordered by key (std::map).
  std::map<string, BundleEntryProto> entries_;
  Status status_;

  TF_DISALLOW_COPY_AND_ASSIGN(BundleWriter);
};
166 | |
// Merges a set of bundles (given their prefixes) into a single bundle with the
// given "merged_prefix". The merged metadata is guaranteed to be consistent.
//
// If there are N bundles in "prefixes", during the merge the data files will be
// renamed to contain a proper sharded file spec, with num_shards set to the sum
// of num_shards across the N input bundles.
//
// The caller should only rely on the metadata file of the merged bundle to
// query information about a tensor. In particular, this function does not
// guarantee not to re-order the input data files.
//
// Once merged, makes a best effort to delete the old metadata files.
// Returns OK iff all bundles are successfully merged.
//
// "env": filesystem environment used to read the inputs and write the output.
//
// "allow_missing_files": If set to true, merges "prefixes" as long as
// at least one file exists. (Defaults to false.)
//
// Returns an InvalidArgumentError when "allow_missing_files" is set to true
// and all data files named in "prefixes" do not exist.
//
// Returns a NotFoundError when "allow_missing_files" is set to false and
// any data file named in "prefixes" does not exist.
Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                    StringPiece merged_prefix,
                    bool allow_missing_files = false);
192 | |
// Reads tensors from a bundle, either by random-access lookup ("Lookup()") or
// by iterating with Seek()/Next().
//
// On construction, silently attempts to read the metadata associated with
// "prefix". If caller intends to call any function afterwards, "status()"
// must be checked.
// All threads accessing the same BundleReader must synchronize.
class BundleReader {
 public:
  BundleReader(Env* const env, StringPiece prefix,
               bool enable_multi_threading_for_testing = false);
  ~BundleReader();

  // Is ok() iff the reader construction is successful (completed the read of
  // the metadata).
  Status status() const { return status_; }

  // Queries whether the bundle contains an entry keyed by "key". Calls Seek()
  // internally, so this call invalidates the reader's current position.
  // REQUIRES: status().ok()
  bool Contains(StringPiece key);

  // Sorts a `container` of tensors to read such that when `Seek(key)` is
  // called on the elements of the sorted container, the underlying file access
  // is sequential. Sorting can greatly improve overall read speed.
  //
  // `get_key` should be a function that when passed an element in `container`,
  // returns the `key` of the tensor.
  //
  // REQUIRES: status().ok()
  template <class T>
  Status SortForSequentialAccess(std::vector<T>& container,
                                 absl::FunctionRef<string(const T&)> get_key);

  // Looks up the dtype and the shape of the tensor keyed by "key".
  // REQUIRES: status().ok()
  Status LookupDtypeAndShape(StringPiece key, DataType* dtype,
                             TensorShape* shape) TF_MUST_USE_RESULT;

  // Looks up the shape of the tensor keyed by "key".
  // Clears "shape" if not found.
  // REQUIRES: status().ok()
  Status LookupTensorShape(StringPiece key,
                           TensorShape* shape) TF_MUST_USE_RESULT;

  // Looks up the tensor keyed by "key".  If "key" refers to a partitioned
  // tensor, attempts to look up the full contents using all stored slices.
  //
  // Caller must make sure "val" has the same shape and dtype as the
  // corresponding contents, so that its buffer can be filled without needing
  // extra allocation.  These can be queried via "LookupDtypeAndShape()".
  //
  // On error, "val" may contain nonsense data.  Returns a NotFound error if
  // tensor keyed by "key" does not exist in this bundle.
  //
  // Validates the stored crc32c checksum against the restored bytes.
  // REQUIRES: status().ok()
  Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;

  // Looks up the tensor pointed to by the internal iterator.
  //
  // On error, "val" may contain nonsense data.
  //
  // Validates the stored crc32c checksum against the restored bytes.
  // REQUIRES: status().ok() && Valid()
  Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT;

  // Looks up the slices of the tensor keyed by "key".  On OK, "slices"
  // is non-empty if and only if the tensor is a partitioned tensor.
  //
  // Warning - there is no guaranteed ordering for the returned slices, so
  // a slice with a larger start index in some dimension could come before
  // another slice with a smaller start index in the same dimension.
  // REQUIRES: status().ok()
  Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
      TF_MUST_USE_RESULT;

  // Looks up a specific slice of a partitioned tensor.
  // It is only required that the stored slices cover the requested slice,
  // namely "slice_spec" is a subset of the union of the stored slices.
  // REQUIRES: status().ok()
  Status LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec,
                     Tensor* val) TF_MUST_USE_RESULT;

  // Seeks to the first position in the bundle whose key is no less than "key".
  // REQUIRES: status().ok()
  void Seek(StringPiece key) { return iter_->Seek(key); }
  // Moves to the next position in the bundle.
  // REQUIRES: status().ok()
  void Next() const { iter_->Next(); }
  // Returns true iff the reader is positioned to a key/val pair.
  // REQUIRES: status().ok()
  bool Valid() const { return iter_->Valid(); }

  // Returns the key at the current position.
  // REQUIRES: status().ok() && Valid()
  StringPiece key() const { return iter_->key(); }
  // Returns the raw value at the current position.
  // REQUIRES: status().ok() && Valid()
  StringPiece value() const { return iter_->value(); }

  string DebugString();

 private:
  // Seeks for "key" and reads the metadata proto.
  // On non-OK return, clears "entry" for the caller.
  // REQUIRES: status().ok()
  Status GetBundleEntryProto(StringPiece key,
                             BundleEntryProto* entry) TF_MUST_USE_RESULT;

  // Reads the tensor value described by the metadata proto "entry".
  // Usage for "val" follows the comment of "Lookup()".
  Status GetValue(const BundleEntryProto& entry,
                  Tensor* val) TF_MUST_USE_RESULT;

  // Reads the slice described by "slice_spec".  The corresponding full tensor
  // has key "full_tensor_key" and metadata proto "full_tensor_entry".
  // REQUIRES: full_tensor_entry.slices_size() > 0
  Status GetSliceValue(StringPiece full_tensor_key,
                       const BundleEntryProto& full_tensor_entry,
                       const TensorSlice& slice_spec,
                       Tensor* val) TF_MUST_USE_RESULT;

  Env* env_;  // Not owned.
  const string prefix_;  // Caller-supplied path prefix of this bundle.

  Status status_;
  RandomAccessFile* metadata_;  // Owned.  The ".index" (metadata) file.
  table::Table* table_;         // String-string table read from metadata_.
  table::Cache* index_cache_;
  table::Iterator* iter_;  // Current position; drives Seek()/Next()/Valid().
  // Owns the InputBuffer objects and their underlying RandomAccessFile's,
  // keyed by data-file shard id.
  std::unordered_map<int32, io::InputBuffer*> data_;

  // Maps each partitioned tensor's key to its stored slices (represented in a
  // TensorSliceSet).  Populated on-demand.
  std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_;

  // Expected number of data file shards in the bundle.  Extracted by reading
  // the header entry in the metadata table.
  int num_shards_;

  // Flag that this class sets to true when the endianness of the target bundle
  // differs from that of the current system's processor architecture.
  bool need_to_swap_bytes_;

  friend class TensorBundleAlignmentTest;  // For testing data alignment.

  bool enable_multi_threading_for_testing_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
};
342 | |
// A buffering wrapper for a WritableFile.  Useful if the caller wishes to issue
// small writes to a file (e.g. writing out a list of small varints).
// External synchronization must be used in the presence of concurrent callers.
class FileOutputBuffer {
 public:
  // Takes ownership of "file" (see file_ below).  "buffer_size" is the
  // capacity, in bytes, of the internal staging buffer.
  FileOutputBuffer(WritableFile* file, size_t buffer_size);
  ~FileOutputBuffer();

  // Buffered append.
  Status Append(StringPiece data);

  // Returns the running crc32c checksum of all currently appended bytes.
  uint32 crc32c() { return crc32c_; }
  // Clears the running crc32c checksum.
  void clear_crc32c() { crc32c_ = 0; }

  // Appends the buffered data, then closes the underlying file.
  Status Close();

 private:
  // Appends the buffered data to the underlying file.  Does NOT flush the
  // file.  "closing" signals that this flush is part of Close().
  Status FlushBuffer(bool closing);

  WritableFile* file_;  // Owned.

  // buffer_ptr_[0, position_) holds the buffered data not yet appended to the
  // underlying file.
  size_t position_;
  const size_t buffer_size_;
  char* buffer_ptr_;

  // Checksum of all appended bytes since construction or last clear_crc32c().
  uint32 crc32c_ = 0;
};
377 | |
378 | template <class T> |
379 | Status BundleReader::SortForSequentialAccess( |
380 | std::vector<T>& container, absl::FunctionRef<string(const T&)> get_key) { |
381 | struct FileOffset { |
382 | int32_t shard_id; |
383 | int64_t offset; |
384 | }; |
385 | absl::flat_hash_map<string, FileOffset> file_offsets; |
386 | for (const T& element : container) { |
387 | BundleEntryProto entry; |
388 | TF_RETURN_IF_ERROR(GetBundleEntryProto(get_key(element), &entry)); |
389 | file_offsets[get_key(element)] = {entry.shard_id(), entry.offset()}; |
390 | } |
391 | absl::c_sort(container, [&get_key, &file_offsets](const T& a, const T& b) { |
392 | const FileOffset& file_offset_a = file_offsets[get_key(a)]; |
393 | const FileOffset& file_offset_b = file_offsets[get_key(b)]; |
394 | if (file_offset_a.shard_id == file_offset_b.shard_id) { |
395 | return file_offset_a.offset < file_offset_b.offset; |
396 | } else { |
397 | return file_offset_a.shard_id < file_offset_b.shard_id; |
398 | } |
399 | }); |
400 | return OkStatus(); |
401 | } |
402 | |
403 | } // namespace tensorflow |
404 | |
405 | #endif // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_ |
406 | |