1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/util/tensor_slice_reader.h"

#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/iterator.h"
#include "tensorflow/core/lib/io/table.h"
#include "tensorflow/core/lib/io/table_options.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_slice_util.h"
34 | |
35 | namespace tensorflow { |
36 | |
37 | namespace checkpoint { |
38 | |
39 | TensorSliceReader::Table::~Table() {} |
40 | |
41 | namespace { |
42 | class TensorSliceReaderTable : public TensorSliceReader::Table { |
43 | public: |
44 | // Takes ownership of 'f'. |
45 | explicit TensorSliceReaderTable(RandomAccessFile* f, table::Table* t) |
46 | : file_(f), table_(t) {} |
47 | |
48 | ~TensorSliceReaderTable() override { |
49 | delete table_; |
50 | delete file_; |
51 | } |
52 | |
53 | bool Get(const string& key, string* value) override { |
54 | std::unique_ptr<table::Iterator> iter(table_->NewIterator()); |
55 | iter->Seek(key); |
56 | if (iter->Valid() && iter->key() == key) { |
57 | StringPiece v = iter->value(); |
58 | value->assign(v.data(), v.size()); |
59 | return true; |
60 | } else { |
61 | return false; |
62 | } |
63 | } |
64 | |
65 | private: |
66 | RandomAccessFile* file_; // Owns. |
67 | table::Table* table_; |
68 | }; |
69 | } // namespace |
70 | |
71 | Status OpenTableTensorSliceReader(const string& fname, |
72 | TensorSliceReader::Table** result) { |
73 | *result = nullptr; |
74 | Env* env = Env::Default(); |
75 | std::unique_ptr<RandomAccessFile> f; |
76 | Status s = env->NewRandomAccessFile(fname, &f); |
77 | if (s.ok()) { |
78 | uint64 file_size; |
79 | s = env->GetFileSize(fname, &file_size); |
80 | if (s.ok()) { |
81 | table::Options options; |
82 | table::Table* table; |
83 | s = table::Table::Open(options, f.get(), file_size, &table); |
84 | if (s.ok()) { |
85 | *result = new TensorSliceReaderTable(f.release(), table); |
86 | return OkStatus(); |
87 | } else { |
88 | s = errors::CreateWithUpdatedMessage( |
89 | s, strings::StrCat(s.error_message(), |
90 | ": perhaps your file is in a different " |
91 | "file format and you need to use a " |
92 | "different restore operator?" )); |
93 | } |
94 | } |
95 | } |
96 | LOG(WARNING) << "Could not open " << fname << ": " << s; |
97 | return s; |
98 | } |
99 | |
// Convenience constructor: opens every shard matching 'filepattern' with the
// default table-based open function.
TensorSliceReader::TensorSliceReader(const string& filepattern)
    : TensorSliceReader(filepattern, OpenTableTensorSliceReader,
                        kLoadAllShards) {}
103 | |
// Convenience constructor: opens every shard matching 'filepattern' with the
// caller-supplied 'open_function'.
TensorSliceReader::TensorSliceReader(const string& filepattern,
                                     OpenTableFunction open_function)
    : TensorSliceReader(filepattern, std::move(open_function), kLoadAllShards) {
}
108 | |
109 | TensorSliceReader::TensorSliceReader(const string& filepattern, |
110 | OpenTableFunction open_function, |
111 | int preferred_shard) |
112 | : filepattern_(filepattern), open_function_(std::move(open_function)) { |
113 | VLOG(1) << "TensorSliceReader for " << filepattern; |
114 | Status s = Env::Default()->GetMatchingPaths(filepattern, &fnames_); |
115 | if (!s.ok()) { |
116 | status_ = errors::InvalidArgument( |
117 | "Unsuccessful TensorSliceReader constructor: " |
118 | "Failed to get matching files on " , |
119 | filepattern, ": " , s.ToString()); |
120 | return; |
121 | } |
122 | if (fnames_.empty()) { |
123 | status_ = errors::NotFound( |
124 | "Unsuccessful TensorSliceReader constructor: " |
125 | "Failed to find any matching files for " , |
126 | filepattern); |
127 | return; |
128 | } |
129 | sss_.resize(fnames_.size()); |
130 | for (size_t shard = 0; shard < fnames_.size(); ++shard) { |
131 | fname_to_index_.insert(std::make_pair(fnames_[shard], shard)); |
132 | } |
133 | if (preferred_shard == kLoadAllShards || fnames_.size() == 1 || |
134 | static_cast<size_t>(preferred_shard) >= fnames_.size()) { |
135 | LoadAllShards(); |
136 | } else { |
137 | VLOG(1) << "Loading shard " << preferred_shard << " for " << filepattern_; |
138 | LoadShard(preferred_shard); |
139 | } |
140 | } |
141 | |
// Loads the meta data of shard 'shard': opens its table file, parses the
// SavedTensorSlices header, and registers every tensor slice it describes
// into tensors_. Errors are reported through status_ (mutable), after which
// the reader is considered invalid. Idempotent: a second call for the same
// shard returns immediately.
void TensorSliceReader::LoadShard(int shard) const {
  CHECK_LT(shard, sss_.size());
  if (sss_[shard] || !status_.ok()) {
    return;  // Already loaded, or invalid.
  }
  string value;
  SavedTensorSlices sts;
  const string fname = fnames_[shard];
  VLOG(1) << "Reading meta data from file " << fname << "...";
  Table* table;
  Status s = open_function_(fname, &table);
  if (!s.ok()) {
    status_ = errors::DataLoss("Unable to open table file ", fname, ": ",
                               s.ToString());
    return;
  }
  // Take ownership of the table so it stays open for later CopySliceData
  // reads and is closed when the reader is destroyed.
  sss_[shard].reset(table);
  // The meta data is stored under a well-known key at the start of the file.
  if (!(table->Get(kSavedTensorSlicesKey, &value) &&
        ParseProtoUnlimited(&sts, value))) {
    status_ = errors::Internal(
        "Failed to find the saved tensor slices at the beginning of the "
        "checkpoint file: ",
        fname);
    return;
  }
  // Reject checkpoints written by an incompatible producer version.
  status_ = CheckVersions(sts.meta().versions(), TF_CHECKPOINT_VERSION,
                          TF_CHECKPOINT_VERSION_MIN_PRODUCER, "Checkpoint",
                          "checkpoint");
  if (!status_.ok()) return;
  // Validate and register every slice of every tensor in this shard; stop at
  // the first failure, leaving status_ describing it.
  for (const SavedSliceMeta& ssm : sts.meta().tensor()) {
    TensorShape ssm_shape;
    status_ = TensorShape::BuildTensorShapeBase(ssm.shape(), &ssm_shape);
    if (!status_.ok()) return;
    for (const TensorSliceProto& tsp : ssm.slice()) {
      TensorSlice ss_slice;
      status_ = TensorSlice::BuildTensorSlice(tsp, &ss_slice);
      if (!status_.ok()) return;
      status_ = RegisterTensorSlice(ssm.name(), ssm_shape, ssm.type(), fname,
                                    ss_slice, &tensors_);
      if (!status_.ok()) return;
    }
  }
}
185 | |
186 | void TensorSliceReader::LoadAllShards() const { |
187 | VLOG(1) << "Loading all shards for " << filepattern_; |
188 | for (size_t i = 0; i < fnames_.size() && status_.ok(); ++i) { |
189 | LoadShard(i); |
190 | } |
191 | all_shards_loaded_ = true; |
192 | } |
193 | |
194 | const TensorSliceSet* TensorSliceReader::FindTensorSlice( |
195 | const string& name, const TensorSlice& slice, |
196 | std::vector<std::pair<TensorSlice, string>>* details) const { |
197 | const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name); |
198 | if (tss && !tss->QueryMeta(slice, details)) { |
199 | return nullptr; |
200 | } |
201 | return tss; |
202 | } |
203 | |
204 | TensorSliceReader::~TensorSliceReader() { |
205 | for (auto& temp : tensors_) { |
206 | delete temp.second; |
207 | } |
208 | tensors_.clear(); |
209 | } |
210 | |
211 | bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape, |
212 | DataType* type) const { |
213 | mutex_lock l(mu_); |
214 | const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name); |
215 | if (!tss && !all_shards_loaded_) { |
216 | VLOG(1) << "Did not find tensor in preferred shard, loading all shards: " |
217 | << name; |
218 | LoadAllShards(); |
219 | tss = gtl::FindPtrOrNull(tensors_, name); |
220 | } |
221 | if (tss) { |
222 | if (shape) { |
223 | *shape = tss->shape(); |
224 | } |
225 | if (type) { |
226 | *type = tss->type(); |
227 | } |
228 | return true; |
229 | } else { |
230 | return false; |
231 | } |
232 | } |
233 | |
// Reads the full tensor 'name' into '*out_tensor'. Only single-slice
// (non-partitioned) tensors of the dtypes listed in the switch below are
// supported. Returns NotFound if the tensor is absent or its data cannot be
// copied, Unimplemented for partitioned tensors or unsupported dtypes.
Status TensorSliceReader::GetTensor(
    const string& name, std::unique_ptr<tensorflow::Tensor>* out_tensor) const {
  DataType type;
  TensorShape shape;
  TensorSlice slice;
  {
    // Only the metadata lookup needs the lock; the copy below happens outside.
    mutex_lock l(mu_);
    const TensorSliceSet* tss = gtl::FindPtrOrNull(tensors_, name);
    if (tss == nullptr) {
      return errors::NotFound(name, " not found in checkpoint file");
    }

    if (tss->Slices().size() > 1) {
      // TODO(sherrym): Support multi-slice checkpoints.
      return errors::Unimplemented("Sliced checkpoints are not supported");
    }

    type = tss->type();
    shape = tss->shape();
    slice = tss->Slices().begin()->second.slice;
  }

  // BuildTensor validates type/shape before allocation.
  std::unique_ptr<tensorflow::Tensor> t(new tensorflow::Tensor);
  Status s = tensorflow::Tensor::BuildTensor(type, shape, t.get());
  if (!s.ok()) return s;
  bool success = false;

// Expands to a switch case that copies the slice data for dtype 'dt' into
// the tensor's flat buffer.
#define READER_COPY(dt)                                                  \
  case dt:                                                               \
    success = CopySliceData(name, slice,                                 \
                            t->flat<EnumToDataType<dt>::Type>().data()); \
    break;

  switch (type) {
    READER_COPY(DT_FLOAT);
    READER_COPY(DT_DOUBLE);
    READER_COPY(DT_INT32);
    READER_COPY(DT_UINT8);
    READER_COPY(DT_INT16);
    READER_COPY(DT_INT8);
    READER_COPY(DT_INT64);
    READER_COPY(DT_STRING);
    default:
      return errors::Unimplemented("Data type not supported");
  }
#undef READER_COPY

  if (!success) {
    return errors::NotFound(name, " not found in checkpoint file");
  }
  // Hand the tensor to the caller.
  std::swap(*out_tensor, t);

  return OkStatus();
}
288 | |
289 | TensorSliceReader::VarToShapeMap TensorSliceReader::GetVariableToShapeMap() |
290 | const { |
291 | VarToShapeMap name_to_shape; |
292 | if (status().ok()) { |
293 | for (auto& e : Tensors()) { |
294 | name_to_shape[e.first] = e.second->shape(); |
295 | } |
296 | } |
297 | return name_to_shape; |
298 | } |
299 | |
300 | TensorSliceReader::VarToDataTypeMap |
301 | TensorSliceReader::GetVariableToDataTypeMap() const { |
302 | VarToDataTypeMap name_to_dtype; |
303 | if (status().ok()) { |
304 | for (auto& e : Tensors()) { |
305 | name_to_dtype[e.first] = e.second->type(); |
306 | } |
307 | } |
308 | return name_to_dtype; |
309 | } |
310 | |
311 | const string TensorSliceReader::DebugString() const { |
312 | string shape_str; |
313 | if (status().ok()) { |
314 | for (const auto& e : Tensors()) { |
315 | strings::StrAppend(&shape_str, e.first, " (" , |
316 | DataType_Name(e.second->type()), ") " , |
317 | e.second->shape().DebugString()); |
318 | // Indicates if a tensor has more than 1 slice (i.e., it's partitioned). |
319 | const int num_slices = e.second->Slices().size(); |
320 | if (num_slices > 1) { |
321 | strings::StrAppend(&shape_str, ", " , num_slices, " slices" ); |
322 | } |
323 | strings::StrAppend(&shape_str, "\n" ); |
324 | } |
325 | } |
326 | return shape_str; |
327 | } |
328 | |
329 | } // namespace checkpoint |
330 | |
331 | } // namespace tensorflow |
332 | |