1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// The utility to read checkpoints for google brain tensor ops and v3
17// checkpoints for dist_belief.
18
19#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
20#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
21
22#include <unordered_map>
23
24#include <vector>
25#include "tensorflow/core/framework/tensor.h"
26#include "tensorflow/core/framework/tensor_shape.h"
27#include "tensorflow/core/framework/tensor_slice.h"
28#include "tensorflow/core/framework/types.pb.h"
29#include "tensorflow/core/lib/core/status.h"
30#include "tensorflow/core/lib/core/stringpiece.h"
31#include "tensorflow/core/lib/gtl/map_util.h"
32#include "tensorflow/core/platform/logging.h"
33#include "tensorflow/core/platform/macros.h"
34#include "tensorflow/core/platform/mutex.h"
35#include "tensorflow/core/platform/protobuf.h"
36#include "tensorflow/core/platform/types.h"
37#include "tensorflow/core/util/saved_tensor_slice.pb.h"
38#include "tensorflow/core/util/saved_tensor_slice_util.h"
39#include "tensorflow/core/util/tensor_slice_set.h"
40#include "tensorflow/core/util/tensor_slice_util.h"
41
42namespace tensorflow {
43
44namespace checkpoint {
45
46// The reader reads in all the meta data about all the tensor slices. Then it
47// will try to read the relevant data on-demand to produce the data for the
48// slices needed.
49// NOTE(yangke): another way to do this is to first load a list of the tensor
50// slices needed and then just selectively read some of the meta data. That
51// might optimize the loading but makes the logic a bit more complicated. We
52// might want to revisit that.
53// TODO(yangke): consider moving to TensorProto.
54class TensorSliceReader {
55 public:
56 // Abstract interface for reading data out of a tensor slice checkpoint file
57 class Table {
58 public:
59 virtual ~Table();
60 virtual bool Get(const string& key, string* value) = 0;
61 };
62 typedef std::function<Status(const string&, Table**)> OpenTableFunction;
63
64 static constexpr int kLoadAllShards = -1;
65 TensorSliceReader(const string& filepattern);
66 TensorSliceReader(const string& filepattern, OpenTableFunction open_function);
67 TensorSliceReader(const string& filepattern, OpenTableFunction open_function,
68 int preferred_shard);
69 virtual ~TensorSliceReader();
70
71 // Get the filename this reader is attached to.
72 const string& filepattern() const { return filepattern_; }
73
74 // Get the number of files matched.
75 int num_files() const { return sss_.size(); }
76
77 // Get the status of the reader.
78 const Status status() const { return status_; }
79
80 // Checks if the reader contains any slice of a tensor. In case the reader
81 // does contain the tensor, if "shape" is not nullptr, fill "shape" with the
82 // shape of the tensor; if "type" is not nullptr, fill "type" with the type
83 // of the tensor.
84 bool HasTensor(const string& name, TensorShape* shape, DataType* type) const;
85
86 // Checks if the reader contains all the data about a tensor slice, and if
87 // yes, copies the data of the slice to "data". The caller needs to make sure
88 // that "data" points to a buffer that holds enough data.
89 // This is a slow function since it needs to read sstables.
90 template <typename T>
91 bool CopySliceData(const string& name, const TensorSlice& slice,
92 T* data) const;
93
94 // Get the tensors.
95 const std::unordered_map<string, TensorSliceSet*>& Tensors() const {
96 return tensors_;
97 }
98
99 // Returns value for one tensor. Only single slice checkpoints are supported
100 // at the moment.
101 Status GetTensor(const string& name,
102 std::unique_ptr<tensorflow::Tensor>* out_tensor) const;
103
104 typedef std::unordered_map<string, TensorShape> VarToShapeMap;
105 typedef std::unordered_map<string, DataType> VarToDataTypeMap;
106
107 // Returns a map from tensor name to shape.
108 VarToShapeMap GetVariableToShapeMap() const;
109
110 // Returns a map from tensor name to data type.
111 VarToDataTypeMap GetVariableToDataTypeMap() const;
112
113 // Returns a string containing names and shapes of all the tensors.
114 const string DebugString() const;
115
116 private:
117 friend class TensorSliceWriteTestHelper;
118
119 void LoadShard(int shard) const;
120 void LoadAllShards() const;
121
122 const TensorSliceSet* FindTensorSlice(
123 const string& name, const TensorSlice& slice,
124 std::vector<std::pair<TensorSlice, string>>* details) const;
125
126 const string filepattern_;
127 const OpenTableFunction open_function_;
128 std::vector<string> fnames_;
129 std::unordered_map<string, int> fname_to_index_;
130
131 // Guards the attributes below.
132 mutable mutex mu_;
133 mutable bool all_shards_loaded_ = false;
134 mutable std::vector<std::unique_ptr<Table>> sss_;
135 mutable std::unordered_map<string, TensorSliceSet*> tensors_;
136 mutable Status status_;
137
138 TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceReader);
139};
140
141Status OpenTableTensorSliceReader(const string& fname,
142 TensorSliceReader::Table** table);
143
144template <typename T>
145bool TensorSliceReader::CopySliceData(const string& name,
146 const TensorSlice& slice, T* data) const {
147 std::vector<std::pair<TensorSlice, string>> details;
148 const TensorSliceSet* tss;
149 {
150 mutex_lock l(mu_);
151 tss = FindTensorSlice(name, slice, &details);
152 if (!tss && !all_shards_loaded_) {
153 VLOG(1) << "Did not find slice in preferred shard, loading all shards."
154 << name << ": " << slice.DebugString();
155 LoadAllShards();
156 tss = FindTensorSlice(name, slice, &details);
157 }
158 if (!tss) {
159 // No such tensor
160 return false;
161 }
162 }
163 // We have the data -- copy it over.
164 string value;
165 for (const auto& x : details) {
166 const TensorSlice& slice_s = x.first;
167 const string& fname = x.second;
168 int idx = gtl::FindWithDefault(fname_to_index_, fname, -1);
169 CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname;
170 // We read a record in the corresponding sstable
171 const string key = EncodeTensorNameSlice(name, slice_s);
172 if (!sss_[idx]->Get(key, &value)) {
173 VLOG(1) << "Failed to seek to the record for tensor " << name
174 << ", slice " << slice_s.DebugString()
175 << ": computed key = " << key;
176 return false;
177 }
178 SavedTensorSlices sts;
179 if (!ParseProtoUnlimited(&sts, value)) {
180 VLOG(1) << "Failed to parse the record for tensor " << name << ", slice "
181 << slice_s.DebugString() << ": computed key = " << key;
182 return false;
183 }
184 // Ensure the TensorSlice contains the expected amount of data.
185 TensorShape shp_s;
186 Status s = slice_s.SliceTensorShape(tss->shape(), &shp_s);
187 if (!s.ok()) {
188 VLOG(1) << "Failed to slice tensor " << name << ", slice "
189 << slice_s.DebugString() << ": " << s;
190 return false;
191 }
192 if (checkpoint::TensorProtoDataSize<T>(sts.data().data()) !=
193 shp_s.num_elements()) {
194 VLOG(1) << "Tensor " << name << ", slice " << slice_s.DebugString()
195 << " had an unexpected amount of data: expected = "
196 << shp_s.num_elements() << ", got = "
197 << checkpoint::TensorProtoDataSize<T>(sts.data().data());
198 return false;
199 }
200 CopyDataFromTensorSliceToTensorSlice(
201 tss->shape(), slice_s, slice,
202 checkpoint::TensorProtoData<T>(sts.data().data()), data);
203 }
204 return true;
205}
206
207} // namespace checkpoint
208
209} // namespace tensorflow
210
211#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
212