1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// A tensor bundle is a set of immutable persistent files storing a set of named
17// tensors. It is designed for checkpointing TensorFlow tensors.
18//
19// The paths of the managed files share a common prefix; e.g., with the prefix:
20// /fs/model/train/ckpt-step/ckpt
21//
22// the bundle may contain a metadata file, and sharded data files:
23// /fs/model/train/ckpt-step/
24// ckpt.index
25// ckpt.data-00000-of-00020
26// ckpt.data-00001-of-00020
27// ...
28// ckpt.data-00019-of-00020
29//
30// The ".index" file is a string-string immutable table
31// (tensorflow::table::Table). Each key is a name of a tensor and its value is
32// a serialized BundleEntryProto. Each BundleEntryProto describes the metadata
33// of a tensor: which of the "data" files contains the content of a tensor, the
34// offset into that file, checksum, some auxiliary data, etc.
35//
36// A tensor bundle can be accessed randomly using a BundleReader. Usage:
37//
38// BundleReader reader(env, "/fs/model/train/ckpt-step/ckpt");
39// reader.Lookup("name", &tensor);
40//
41// A tensor bundle can be built using BundleWriter. Each BundleWriter builds a
42// single data file bundle. Multiple bundles can then be merged by
// MergeBundles() without reading and writing large chunks of data: it reads the
44// metadata files and outputs a single merged metadata. Typical usage:
45//
46// worker 0:
47// BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker0-step");
48// writer.Add(...); // Adds the tensors on this worker.
49// writer.Finish(); // Flushes.
50// worker 1:
51// BundleWriter writer(env, "/fs/model/train/ckpt-step/tmp/worker1-step");
52// writer.Add(...);
53// writer.Finish();
54// worker 2:
55// MergeBundles(env,
56// {"/fs/model/train/ckpt-step/tmp/worker0-step",
57// "/fs/model/train/ckpt-step/tmp/worker1-step"},
58// "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
59//
60
61#ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
62#define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
63
64#include <map>
65#include <string>
66#include <unordered_map>
67
68#include "absl/algorithm/container.h"
69#include "absl/container/flat_hash_map.h"
70#include "absl/functional/function_ref.h"
71#include "tensorflow/core/framework/tensor.h"
72#include "tensorflow/core/framework/tensor_shape.h"
73#include "tensorflow/core/framework/tensor_slice.h"
74#include "tensorflow/core/lib/core/status.h"
75#include "tensorflow/core/lib/gtl/array_slice.h"
76#include "tensorflow/core/lib/io/cache.h"
77#include "tensorflow/core/lib/io/inputbuffer.h"
78#include "tensorflow/core/lib/io/table.h"
79#include "tensorflow/core/platform/cord.h"
80#include "tensorflow/core/platform/env.h"
81#include "tensorflow/core/platform/file_system.h"
82#include "tensorflow/core/platform/macros.h"
83#include "tensorflow/core/platform/types.h"
84#include "tensorflow/core/protobuf/tensor_bundle.pb.h"
85#include "tensorflow/core/util/tensor_bundle/naming.h"
86#include "tensorflow/core/util/tensor_slice_set.h"
87
88namespace tensorflow {
89
// Forward declaration; defined at the bottom of this header.
class FileOutputBuffer;

// Versioning of the tensor bundle format.
// Follows the same rules as 3p/tf/core/public/version.h.
//
// History:
// 0. Any tensor bundles produced before this field was added.
// 1. Added this field (2016-09-14).
extern const int kTensorBundleMinProducer;
extern const int kTensorBundleMinConsumer;
extern const int kTensorBundleVersion;

// The empty string, hence always the first key in the metadata table. Its
// corresponding value is a BundleHeaderProto.
extern const char* const kHeaderEntryKey;
105
// Builds a string-string table of tensor names to BundleEntryProto (metadata).
//
// On construction, attempts to create a directory given by the dirname of
// "prefix", so "status()" must be checked before calling any member functions.
//
// All threads accessing the same BundleWriter must synchronize.
class BundleWriter {
 public:
  struct Options {
    Options() {}
    // Alignment, in bytes, for tensor data.
    // Must be >= 1. The default size of 1 densely packs tensors.
    int data_alignment{1};
  };
  BundleWriter(Env* env, StringPiece prefix,
               const Options& options = Options());

  // Adds the tensor "val" under key "key".
  // Across calls "key" must be unique but can be added in any order.
  Status Add(StringPiece key, const Tensor& val);

  // Partitioned variables support.
  // A slice of a full tensor is stored in two entries in the metadata table:
  //
  //   full_tensor_key   -> BundleEntryProto, describing all stored slices
  //                        of this full tensor. Does not append to the data
  //                        file.
  //   encoded slice key -> BundleEntryProto, describing one particular slice.
  //                        Appends values of this slice to the data file.
  //
  // Slices of a full tensor can be added in any order.
  //
  // If a full tensor has slices placed on N devices and N BundleWriter's are
  // concurrently used, the caller must use MergeBundles() to ensure that a
  // consistent entry for "full_tensor_key" is produced.
  //
  // Returns an error if the same slice is added the second time.
  Status AddSlice(StringPiece full_tensor_key,
                  const TensorShape& full_tensor_shape,
                  const TensorSlice& slice_spec, const Tensor& slice_tensor);

  // Finishes the writer and flushes.
  Status Finish() TF_MUST_USE_RESULT;

  // Status of construction and of all operations performed so far.
  Status status() const { return status_; }

 private:
  Env* const env_;  // Not owned.
  const Options options_;
  const string prefix_;
  // Paths of the metadata (".index") and data files being written.
  string metadata_path_;
  string data_path_;
  // Whether output is staged through a temporary file — NOTE(review):
  // presumably renamed into place by Finish(); confirm against the .cc.
  bool use_temp_file_;
  // Buffered writer for the data file.
  std::unique_ptr<FileOutputBuffer> out_;
  int64_t size_;  // Number of bytes written into out_.
  // Tensor name -> metadata entry, kept sorted by name (std::map).
  std::map<string, BundleEntryProto> entries_;
  Status status_;

  TF_DISALLOW_COPY_AND_ASSIGN(BundleWriter);
};
166
// Merges a set of bundles (given their prefixes) into a single bundle with the
// given "merged_prefix". The merged metadata is guaranteed to be consistent.
//
// If there are N bundles in "prefixes", during the merge the data files will be
// renamed to contain a proper sharded file spec, with num_shards set to the sum
// of num_shards across the N input bundles.
//
// The caller should only rely on the metadata file of the merged bundle to
// query information about a tensor. In particular, this function does not
// guarantee not to re-order the input data files.
//
// Once merged, makes a best effort to delete the old metadata files.
// Returns OK iff all bundles are successfully merged.
//
// "allow_missing_files": If set to true, merges "prefixes" as long as
// at least one file exists. (Defaults to false.)
//
// Returns an InvalidArgumentError when "allow_missing_files" is set to true
// and none of the data files named in "prefixes" exist.
//
// Returns a NotFoundError when "allow_missing_files" is set to false and
// any data file named in "prefixes" does not exist.
Status MergeBundles(Env* env, gtl::ArraySlice<tstring> prefixes,
                    StringPiece merged_prefix,
                    bool allow_missing_files = false);
192
// On construction, silently attempts to read the metadata associated with
// "prefix". If caller intends to call any function afterwards, "status()"
// must be checked.
// All threads accessing the same BundleReader must synchronize.
class BundleReader {
 public:
  BundleReader(Env* const env, StringPiece prefix,
               bool enable_multi_threading_for_testing = false);
  ~BundleReader();

  // Is ok() iff the reader construction is successful (completed the read of
  // the metadata).
  Status status() const { return status_; }

  // Queries whether the bundle contains an entry keyed by "key". Calls Seek()
  // internally, so this call invalidates the reader's current position.
  // REQUIRES: status().ok()
  bool Contains(StringPiece key);

  // Sorts a `container` of tensors to read such that when `Seek(key)` is called
  // on the elements of the sorted container, the underlying file access is
  // sequential. Sorting can greatly improve overall read speed.
  //
  // `get_key` should be a function that when passed an element in `container`,
  // returns the `key` of the tensor.
  //
  // REQUIRES: status().ok()
  template <class T>
  Status SortForSequentialAccess(std::vector<T>& container,
                                 absl::FunctionRef<string(const T&)> get_key);

  // Looks up the dtype and the shape of the tensor keyed by "key".
  // REQUIRES: status().ok()
  Status LookupDtypeAndShape(StringPiece key, DataType* dtype,
                             TensorShape* shape) TF_MUST_USE_RESULT;

  // Looks up the shape of the tensor keyed by "key".
  // Clears "shape" if not found.
  // REQUIRES: status().ok()
  Status LookupTensorShape(StringPiece key,
                           TensorShape* shape) TF_MUST_USE_RESULT;

  // Looks up the tensor keyed by "key". If "key" refers to a partitioned
  // tensor, attempts to look up the full contents using all stored slices.
  //
  // Caller must make sure "val" has the same shape and dtype as the
  // corresponding contents, so that its buffer can be filled without needing
  // extra allocation. These can be queried via "LookupDtypeAndShape()".
  //
  // On error, "val" may contain nonsense data. Returns a NotFound error if
  // tensor keyed by "key" does not exist in this bundle.
  //
  // Validates the stored crc32c checksum against the restored bytes.
  // REQUIRES: status().ok()
  Status Lookup(StringPiece key, Tensor* val) TF_MUST_USE_RESULT;

  // Looks up the tensor pointed to by the internal iterator.
  //
  // On error, "val" may contain nonsense data.
  //
  // Validates the stored crc32c checksum against the restored bytes.
  // REQUIRES: status().ok() && Valid()
  Status ReadCurrent(Tensor* val) TF_MUST_USE_RESULT;

  // Looks up the slices of the tensor keyed by "key". On OK, "slices"
  // is non-empty if and only if the tensor is a partitioned tensor.
  //
  // Warning - there is no guaranteed ordering for the returned slices, so
  // a slice with a larger start index in some dimension could come before
  // another slice with a smaller start index in the same dimension.
  // REQUIRES: status().ok()
  Status LookupTensorSlices(StringPiece key, std::vector<TensorSlice>* slices)
      TF_MUST_USE_RESULT;

  // Looks up a specific slice of a partitioned tensor.
  // It is only required that the stored slices cover the requested slice,
  // namely "slice_spec" is a subset of the union of the stored slices.
  // REQUIRES: status().ok()
  Status LookupSlice(StringPiece full_tensor_key, const TensorSlice& slice_spec,
                     Tensor* val) TF_MUST_USE_RESULT;

  // Seeks to the first position in the bundle whose key is no less than "key".
  // REQUIRES: status().ok()
  void Seek(StringPiece key) { return iter_->Seek(key); }
  // Moves to the next position in the bundle.
  // REQUIRES: status().ok()
  void Next() const { iter_->Next(); }
  // Returns true iff the reader is positioned to a key/val pair.
  // REQUIRES: status().ok()
  bool Valid() const { return iter_->Valid(); }

  // Returns the key at the current position.
  // REQUIRES: status().ok() && Valid()
  StringPiece key() const { return iter_->key(); }
  // Returns the raw value at the current position.
  // REQUIRES: status().ok() && Valid()
  StringPiece value() const { return iter_->value(); }

  string DebugString();

 private:
  // Seeks for "key" and reads the metadata proto.
  // On non-OK return, clears "entry" for the caller.
  // REQUIRES: status().ok()
  Status GetBundleEntryProto(StringPiece key,
                             BundleEntryProto* entry) TF_MUST_USE_RESULT;

  // Reads the tensor value described by the metadata proto "entry".
  // Usage for "val" follows the comment of "Lookup()".
  Status GetValue(const BundleEntryProto& entry,
                  Tensor* val) TF_MUST_USE_RESULT;

  // Reads the slice described by "slice_spec". The corresponding full tensor
  // has key "full_tensor_key" and metadata proto "full_tensor_entry".
  // REQUIRES: full_tensor_entry.slices_size() > 0
  Status GetSliceValue(StringPiece full_tensor_key,
                       const BundleEntryProto& full_tensor_entry,
                       const TensorSlice& slice_spec,
                       Tensor* val) TF_MUST_USE_RESULT;

  Env* env_;  // Not owned.
  const string prefix_;

  Status status_;
  RandomAccessFile* metadata_;  // Owned.
  // Metadata table parsed from the ".index" file (key -> BundleEntryProto).
  table::Table* table_;
  // Block cache used when reading the index — NOTE(review): ownership and
  // nullability are not visible in this header; confirm against the .cc.
  table::Cache* index_cache_;
  // Iterator over table_; backs Seek()/Next()/Valid()/key()/value().
  table::Iterator* iter_;
  // Owns the InputBuffer objects and their underlying RandomAccessFile's,
  // keyed by data-file shard id (see BundleEntryProto.shard_id).
  std::unordered_map<int32, io::InputBuffer*> data_;

  // Maps each partitioned tensor's key to its stored slices (represented in a
  // TensorSliceSet). Populated on-demand.
  std::unordered_map<string, checkpoint::TensorSliceSet*> tensor_slices_;

  // Expected number of data file shards in the bundle. Extracted by reading
  // the header entry in the metadata table.
  int num_shards_;

  // Flag that this class sets to true when the endianness of the target bundle
  // differs from that of the current system's processor architecture.
  bool need_to_swap_bytes_;

  friend class TensorBundleAlignmentTest;  // For testing data alignment.

  bool enable_multi_threading_for_testing_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(BundleReader);
};
342
// A buffering wrapper for a WritableFile. Useful if the caller wishes to issue
// small writes to a file (e.g. writing out a list of small varints).
// External synchronization must be used in the presence of concurrent callers.
class FileOutputBuffer {
 public:
  // Takes ownership of "file" (see `file_` below); buffers up to
  // "buffer_size" bytes before appending to the underlying file.
  FileOutputBuffer(WritableFile* file, size_t buffer_size);
  ~FileOutputBuffer();

  // Buffered append.
  Status Append(StringPiece data);

  // Returns the running crc32c checksum of all currently appended bytes.
  uint32 crc32c() { return crc32c_; }
  // Clears the running crc32c checksum.
  void clear_crc32c() { crc32c_ = 0; }

  // Appends the buffered data, then closes the underlying file.
  Status Close();

 private:
  // Appends the buffered data to the underlying file. Does NOT flush the file.
  Status FlushBuffer(bool closing);

  WritableFile* file_;  // Owned.

  // buffer_ptr_[0, position_) holds the buffered data not yet appended to the
  // underlying file.
  size_t position_;
  const size_t buffer_size_;
  char* buffer_ptr_;

  // Checksum of all appended bytes since construction or last clear_crc32c().
  uint32 crc32c_ = 0;
};
377
378template <class T>
379Status BundleReader::SortForSequentialAccess(
380 std::vector<T>& container, absl::FunctionRef<string(const T&)> get_key) {
381 struct FileOffset {
382 int32_t shard_id;
383 int64_t offset;
384 };
385 absl::flat_hash_map<string, FileOffset> file_offsets;
386 for (const T& element : container) {
387 BundleEntryProto entry;
388 TF_RETURN_IF_ERROR(GetBundleEntryProto(get_key(element), &entry));
389 file_offsets[get_key(element)] = {entry.shard_id(), entry.offset()};
390 }
391 absl::c_sort(container, [&get_key, &file_offsets](const T& a, const T& b) {
392 const FileOffset& file_offset_a = file_offsets[get_key(a)];
393 const FileOffset& file_offset_b = file_offsets[get_key(b)];
394 if (file_offset_a.shard_id == file_offset_b.shard_id) {
395 return file_offset_a.offset < file_offset_b.offset;
396 } else {
397 return file_offset_a.shard_id < file_offset_b.shard_id;
398 }
399 });
400 return OkStatus();
401}
402
403} // namespace tensorflow
404
405#endif // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
406