1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/util/tensor_slice_reader.h"
17
18#include <utility>
19#include <vector>
20
21#include "tensorflow/core/framework/types.pb.h"
22#include "tensorflow/core/framework/versions.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/lib/io/iterator.h"
25#include "tensorflow/core/lib/io/table.h"
26#include "tensorflow/core/lib/io/table_options.h"
27#include "tensorflow/core/platform/env.h"
28#include "tensorflow/core/platform/errors.h"
29#include "tensorflow/core/platform/logging.h"
30#include "tensorflow/core/platform/protobuf.h"
31#include "tensorflow/core/public/version.h"
32#include "tensorflow/core/util/saved_tensor_slice_util.h"
33#include "tensorflow/core/util/tensor_slice_util.h"
34
35namespace tensorflow {
36
37namespace checkpoint {
38
39TensorSliceReader::Table::~Table() {}
40
41namespace {
42class TensorSliceReaderTable : public TensorSliceReader::Table {
43 public:
44 // Takes ownership of 'f'.
45 explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t)
46 : file_(f), table_(t) {}
47
48 ~TensorSliceReaderTable() override {
49 delete table_;
50 delete file_;
51 }
52
53 bool Get(const string& key, string* value) override {
54 std::unique_ptr<table::Iterator> iter(table_->NewIterator());
55 iter->Seek(key);
56 if (iter->Valid() && iter->key() == key) {
57 StringPiece v = iter->value();
58 value->assign(v.data(), v.size());
59 return true;
60 } else {
61 return false;
62 }
63 }
64
65 private:
66 RandomAccessFile* file_; // Owns.
67 table::Table* table_;
68};
69} // namespace
70
71Status OpenTableTensorSliceReader(const string& fname,
72 TensorSliceReader::Table** result) {
73 *result = nullptr;
74 Env* env = Env::Default();
75 std::unique_ptr<RandomAccessFile> f;
76 Status s = env->NewRandomAccessFile(fname, &f);
77 if (s.ok()) {
78 uint64 file_size;
79 s = env->GetFileSize(fname, &file_size);
80 if (s.ok()) {
81 table::Options options;
82 table::Table* table;
83 s = table::Table::Open(options, f.get(), file_size, &table);
84 if (s.ok()) {
85 *result = new TensorSliceReaderTable(f.release(), table);
86 return OkStatus();
87 } else {
88 s = errors::CreateWithUpdatedMessage(
89 s, strings::StrCat(s.error_message(),
90 ": perhaps your file is in a different "
91 "file format and you need to use a "
92 "different restore operator?"));
93 }
94 }
95 }
96 LOG(WARNING) << "Could not open " << fname << ": " << s;
97 return s;
98}
99
100TensorSliceReader::TensorSliceReader(const string& filepattern)
101 : TensorSliceReader(filepattern, OpenTableTensorSliceReader,
102 kLoadAllShards) {}
103
104TensorSliceReader::TensorSliceReader(const string& filepattern,
105 OpenTableFunction open_function)
106 : TensorSliceReader(filepattern, std::move(open_function), kLoadAllShards) {
107}
108
109TensorSliceReader::TensorSliceReader(const string& filepattern,
110 OpenTableFunction open_function,
111 int preferred_shard)
112 : filepattern_(filepattern), open_function_(std::move(open_function)) {
113 VLOG(1) << "TensorSliceReader for " << filepattern;
114 Status s = Env::Default()->GetMatchingPaths(filepattern, &fnames_);
115 if (!s.ok()) {
116 status_ = errors::InvalidArgument(
117 "Unsuccessful TensorSliceReader constructor: "
118 "Failed to get matching files on ",
119 filepattern, ": ", s.ToString());
120 return;
121 }
122 if (fnames_.empty()) {
123 status_ = errors::NotFound(
124 "Unsuccessful TensorSliceReader constructor: "
125 "Failed to find any matching files for ",
126 filepattern);
127 return;
128 }
129 sss_.resize(fnames_.size());
130 for (size_t shard = 0; shard < fnames_.size(); ++shard) {
131 fname_to_index_.insert(std::make_pair(fnames_[shard], shard));
132 }
133 if (preferred_shard == kLoadAllShards || fnames_.size() == 1 ||
134 static_cast<size_t>(preferred_shard) >= fnames_.size()) {
135 LoadAllShards();
136 } else {
137 VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_;
138 LoadShard(preferred_shard);
139 }
140}
141
142void TensorSliceReader::LoadShard(int shard) const {
143 CHECK_LT(shard, sss_.size());
144 if (sss_[shard] || !status_.ok()) {
145 return; // Already loaded, or invalid.
146 }
147 string value;
148 SavedTensorSlices sts;
149 const string fname = fnames_[shard];
150 VLOG(1) << "Reading meta data from file " << fname << "...";
151 Table* table;
152 Status s = open_function_(fname, &table);
153 if (!s.ok()) {
154 status_ = errors::DataLoss("Unable to open table file ", fname, ": ",
155 s.ToString());
156 return;
157 }
158 sss_[shard].reset(table);
159 if (!(table->Get(kSavedTensorSlicesKey, &value) &&
160 ParseProtoUnlimited(&sts, value))) {
161 status_ = errors::Internal(
162 "Failed to find the saved tensor slices at the beginning of the "
163 "checkpoint file: ",
164 fname);
165 return;
166 }
167 status_ = CheckVersions(sts.meta().versions(), TF_CHECKPOINT_VERSION,
168 TF_CHECKPOINT_VERSION_MIN_PRODUCER, "Checkpoint",
169 "checkpoint");
170 if (!status_.ok()) return;
171 for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
172 TensorShape ssm_shape;
173 status_ = TensorShape::BuildTensorShapeBase(ssm.shape(), &ssm_shape);
174 if (!status_.ok()) return;
175 for (const TensorSliceProto& tsp : ssm.slice()) {
176 TensorSlice ss_slice;
177 status_ = TensorSlice::BuildTensorSlice(tsp, &ss_slice);
178 if (!status_.ok()) return;
179 status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
180 ss_slice, &tensors_);
181 if (!status_.ok()) return;
182 }
183 }
184}
185
186void TensorSliceReader::LoadAllShards() const {
187 VLOG(1) << "Loading all shards for " << filepattern_;
188 for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) {
189 LoadShard(i);
190 }
191 all_shards_loaded_ = true;
192}
193
194const TensorSliceSet* TensorSliceReader::FindTensorSlice(
195 const string& name, const TensorSlice& slice,
196 std::vector<std::pair<TensorSlice, string>>* details) const {
197 const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
198 if (tss && !tss->QueryMeta(slice, details)) {
199 return nullptr;
200 }
201 return tss;
202}
203
204TensorSliceReader::~TensorSliceReader() {
205 for (auto& temp : tensors_) {
206 delete temp.second;
207 }
208 tensors_.clear();
209}
210
211bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape,
212 DataType* type) const {
213 mutex_lock l(mu_);
214 const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
215 if (!tss && !all_shards_loaded_) {
216 VLOG(1) << "Did not find tensor in preferred shard, loading all shards: "
217 << name;
218 LoadAllShards();
219 tss = gtl::FindPtrOrNull(tensors_, name);
220 }
221 if (tss) {
222 if (shape) {
223 *shape = tss->shape();
224 }
225 if (type) {
226 *type = tss->type();
227 }
228 return true;
229 } else {
230 return false;
231 }
232}
233
234Status TensorSliceReader::GetTensor(
235 const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor) const {
236 DataType type;
237 TensorShape shape;
238 TensorSlice slice;
239 {
240 mutex_lock l(mu_);
241 const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
242 if (tss == nullptr) {
243 return errors::NotFound(name, " not found in checkpoint file");
244 }
245
246 if (tss->Slices().size() > 1) {
247 // TODO(sherrym): Support multi-slice checkpoints.
248 return errors::Unimplemented("Sliced checkpoints are not supported");
249 }
250
251 type = tss->type();
252 shape = tss->shape();
253 slice = tss->Slices().begin()->second.slice;
254 }
255
256 std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor);
257 Status s = tensorflow::Tensor::BuildTensor(type, shape, t.get());
258 if (!s.ok()) return s;
259 bool success = false;
260
261#define READER_COPY(dt) \
262 case dt: \
263 success = CopySliceData(name, slice, \
264 t->flat<EnumToDataType<dt>::Type>().data()); \
265 break;
266
267 switch (type) {
268 READER_COPY(DT_FLOAT);
269 READER_COPY(DT_DOUBLE);
270 READER_COPY(DT_INT32);
271 READER_COPY(DT_UINT8);
272 READER_COPY(DT_INT16);
273 READER_COPY(DT_INT8);
274 READER_COPY(DT_INT64);
275 READER_COPY(DT_STRING);
276 default:
277 return errors::Unimplemented("Data type not supported");
278 }
279#undef READER_COPY
280
281 if (!success) {
282 return errors::NotFound(name, " not found in checkpoint file");
283 }
284 std::swap(*out_tensor, t);
285
286 return OkStatus();
287}
288
289TensorSliceReader::VarToShapeMap TensorSliceReader::GetVariableToShapeMap()
290 const {
291 VarToShapeMap name_to_shape;
292 if (status().ok()) {
293 for (auto& e : Tensors()) {
294 name_to_shape[e.first] = e.second->shape();
295 }
296 }
297 return name_to_shape;
298}
299
300TensorSliceReader::VarToDataTypeMap
301TensorSliceReader::GetVariableToDataTypeMap() const {
302 VarToDataTypeMap name_to_dtype;
303 if (status().ok()) {
304 for (auto& e : Tensors()) {
305 name_to_dtype[e.first] = e.second->type();
306 }
307 }
308 return name_to_dtype;
309}
310
311const string TensorSliceReader::DebugString() const {
312 string shape_str;
313 if (status().ok()) {
314 for (const auto& e : Tensors()) {
315 strings::StrAppend(&shape_str, e.first, " (",
316 DataType_Name(e.second->type()), ") ",
317 e.second->shape().DebugString());
318 // Indicates if a tensor has more than 1 slice (i.e., it's partitioned).
319 const int num_slices = e.second->Slices().size();
320 if (num_slices > 1) {
321 strings::StrAppend(&shape_str, ", ", num_slices, " slices");
322 }
323 strings::StrAppend(&shape_str, "\n");
324 }
325 }
326 return shape_str;
327}
328
329} // namespace checkpoint
330
331} // namespace tensorflow
332