1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
// Utility for reading checkpoints written by Google Brain tensor ops, as well
// as v3 checkpoints written by DistBelief.
18 | |
19 | #ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_ |
20 | #define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_ |
21 | |
22 | #include <unordered_map> |
23 | |
24 | #include <vector> |
25 | #include "tensorflow/core/framework/tensor.h" |
26 | #include "tensorflow/core/framework/tensor_shape.h" |
27 | #include "tensorflow/core/framework/tensor_slice.h" |
28 | #include "tensorflow/core/framework/types.pb.h" |
29 | #include "tensorflow/core/lib/core/status.h" |
30 | #include "tensorflow/core/lib/core/stringpiece.h" |
31 | #include "tensorflow/core/lib/gtl/map_util.h" |
32 | #include "tensorflow/core/platform/logging.h" |
33 | #include "tensorflow/core/platform/macros.h" |
34 | #include "tensorflow/core/platform/mutex.h" |
35 | #include "tensorflow/core/platform/protobuf.h" |
36 | #include "tensorflow/core/platform/types.h" |
37 | #include "tensorflow/core/util/saved_tensor_slice.pb.h" |
38 | #include "tensorflow/core/util/saved_tensor_slice_util.h" |
39 | #include "tensorflow/core/util/tensor_slice_set.h" |
40 | #include "tensorflow/core/util/tensor_slice_util.h" |
41 | |
42 | namespace tensorflow { |
43 | |
44 | namespace checkpoint { |
45 | |
// The reader reads in the metadata about the tensor slices stored in the
// checkpoint shards (possibly only a preferred shard at first, loading the
// remaining shards on demand when a lookup misses). It then reads the relevant
// data on demand to produce the data for the slices needed.
49 | // NOTE(yangke): another way to do this is to first load a list of the tensor |
50 | // slices needed and then just selectively read some of the meta data. That |
51 | // might optimize the loading but makes the logic a bit more complicated. We |
52 | // might want to revisit that. |
53 | // TODO(yangke): consider moving to TensorProto. |
54 | class TensorSliceReader { |
55 | public: |
56 | // Abstract interface for reading data out of a tensor slice checkpoint file |
57 | class Table { |
58 | public: |
59 | virtual ~Table(); |
60 | virtual bool Get(const string& key, string* value) = 0; |
61 | }; |
62 | typedef std::function<Status(const string&, Table**)> OpenTableFunction; |
63 | |
64 | static constexpr int kLoadAllShards = -1; |
65 | TensorSliceReader(const string& filepattern); |
66 | TensorSliceReader(const string& filepattern, OpenTableFunction open_function); |
67 | TensorSliceReader(const string& filepattern, OpenTableFunction open_function, |
68 | int preferred_shard); |
69 | virtual ~TensorSliceReader(); |
70 | |
71 | // Get the filename this reader is attached to. |
72 | const string& filepattern() const { return filepattern_; } |
73 | |
74 | // Get the number of files matched. |
75 | int num_files() const { return sss_.size(); } |
76 | |
77 | // Get the status of the reader. |
78 | const Status status() const { return status_; } |
79 | |
80 | // Checks if the reader contains any slice of a tensor. In case the reader |
81 | // does contain the tensor, if "shape" is not nullptr, fill "shape" with the |
82 | // shape of the tensor; if "type" is not nullptr, fill "type" with the type |
83 | // of the tensor. |
84 | bool HasTensor(const string& name, TensorShape* shape, DataType* type) const; |
85 | |
86 | // Checks if the reader contains all the data about a tensor slice, and if |
87 | // yes, copies the data of the slice to "data". The caller needs to make sure |
88 | // that "data" points to a buffer that holds enough data. |
89 | // This is a slow function since it needs to read sstables. |
90 | template <typename T> |
91 | bool CopySliceData(const string& name, const TensorSlice& slice, |
92 | T* data) const; |
93 | |
94 | // Get the tensors. |
95 | const std::unordered_map<string, TensorSliceSet*>& Tensors() const { |
96 | return tensors_; |
97 | } |
98 | |
99 | // Returns value for one tensor. Only single slice checkpoints are supported |
100 | // at the moment. |
101 | Status GetTensor(const string& name, |
102 | std::unique_ptr<tensorflow::Tensor>* out_tensor) const; |
103 | |
104 | typedef std::unordered_map<string, TensorShape> VarToShapeMap; |
105 | typedef std::unordered_map<string, DataType> VarToDataTypeMap; |
106 | |
107 | // Returns a map from tensor name to shape. |
108 | VarToShapeMap GetVariableToShapeMap() const; |
109 | |
110 | // Returns a map from tensor name to data type. |
111 | VarToDataTypeMap GetVariableToDataTypeMap() const; |
112 | |
113 | // Returns a string containing names and shapes of all the tensors. |
114 | const string DebugString() const; |
115 | |
116 | private: |
117 | friend class TensorSliceWriteTestHelper; |
118 | |
119 | void LoadShard(int shard) const; |
120 | void LoadAllShards() const; |
121 | |
122 | const TensorSliceSet* FindTensorSlice( |
123 | const string& name, const TensorSlice& slice, |
124 | std::vector<std::pair<TensorSlice, string>>* details) const; |
125 | |
126 | const string filepattern_; |
127 | const OpenTableFunction open_function_; |
128 | std::vector<string> fnames_; |
129 | std::unordered_map<string, int> fname_to_index_; |
130 | |
131 | // Guards the attributes below. |
132 | mutable mutex mu_; |
133 | mutable bool all_shards_loaded_ = false; |
134 | mutable std::vector<std::unique_ptr<Table>> sss_; |
135 | mutable std::unordered_map<string, TensorSliceSet*> tensors_; |
136 | mutable Status status_; |
137 | |
138 | TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceReader); |
139 | }; |
140 | |
141 | Status OpenTableTensorSliceReader(const string& fname, |
142 | TensorSliceReader::Table** table); |
143 | |
144 | template <typename T> |
145 | bool TensorSliceReader::CopySliceData(const string& name, |
146 | const TensorSlice& slice, T* data) const { |
147 | std::vector<std::pair<TensorSlice, string>> details; |
148 | const TensorSliceSet* tss; |
149 | { |
150 | mutex_lock l(mu_); |
151 | tss = FindTensorSlice(name, slice, &details); |
152 | if (!tss && !all_shards_loaded_) { |
153 | VLOG(1) << "Did not find slice in preferred shard, loading all shards." |
154 | << name << ": " << slice.DebugString(); |
155 | LoadAllShards(); |
156 | tss = FindTensorSlice(name, slice, &details); |
157 | } |
158 | if (!tss) { |
159 | // No such tensor |
160 | return false; |
161 | } |
162 | } |
163 | // We have the data -- copy it over. |
164 | string value; |
165 | for (const auto& x : details) { |
166 | const TensorSlice& slice_s = x.first; |
167 | const string& fname = x.second; |
168 | int idx = gtl::FindWithDefault(fname_to_index_, fname, -1); |
169 | CHECK_GE(idx, 0) << "Failed to find the index for filename " << fname; |
170 | // We read a record in the corresponding sstable |
171 | const string key = EncodeTensorNameSlice(name, slice_s); |
172 | if (!sss_[idx]->Get(key, &value)) { |
173 | VLOG(1) << "Failed to seek to the record for tensor " << name |
174 | << ", slice " << slice_s.DebugString() |
175 | << ": computed key = " << key; |
176 | return false; |
177 | } |
178 | SavedTensorSlices sts; |
179 | if (!ParseProtoUnlimited(&sts, value)) { |
180 | VLOG(1) << "Failed to parse the record for tensor " << name << ", slice " |
181 | << slice_s.DebugString() << ": computed key = " << key; |
182 | return false; |
183 | } |
184 | // Ensure the TensorSlice contains the expected amount of data. |
185 | TensorShape shp_s; |
186 | Status s = slice_s.SliceTensorShape(tss->shape(), &shp_s); |
187 | if (!s.ok()) { |
188 | VLOG(1) << "Failed to slice tensor " << name << ", slice " |
189 | << slice_s.DebugString() << ": " << s; |
190 | return false; |
191 | } |
192 | if (checkpoint::TensorProtoDataSize<T>(sts.data().data()) != |
193 | shp_s.num_elements()) { |
194 | VLOG(1) << "Tensor " << name << ", slice " << slice_s.DebugString() |
195 | << " had an unexpected amount of data: expected = " |
196 | << shp_s.num_elements() << ", got = " |
197 | << checkpoint::TensorProtoDataSize<T>(sts.data().data()); |
198 | return false; |
199 | } |
200 | CopyDataFromTensorSliceToTensorSlice( |
201 | tss->shape(), slice_s, slice, |
202 | checkpoint::TensorProtoData<T>(sts.data().data()), data); |
203 | } |
204 | return true; |
205 | } |
206 | |
207 | } // namespace checkpoint |
208 | |
209 | } // namespace tensorflow |
210 | |
211 | #endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_ |
212 | |