1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // The utility to write checkpoints for google brain tensor ops and v3 |
17 | // checkpoints for dist_belief. |
18 | |
19 | #ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_WRITER_H_ |
20 | #define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_WRITER_H_ |
21 | |
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice.pb.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
37 | |
38 | namespace tensorflow { |
39 | |
40 | namespace checkpoint { |
41 | |
42 | class TensorSliceWriter { |
43 | public: |
44 | // Abstract interface that TensorSliceWriter uses for building |
45 | class Builder { |
46 | public: |
47 | virtual ~Builder() {} |
48 | virtual void Add(StringPiece key, StringPiece value) = 0; |
49 | virtual Status Finish(int64_t* file_size) = 0; |
50 | }; |
51 | typedef std::function<Status(const string&, Builder**)> CreateBuilderFunction; |
52 | |
53 | TensorSliceWriter(const string& filename, |
54 | CreateBuilderFunction create_builder); |
55 | virtual ~TensorSliceWriter() {} |
56 | // Adds a slice. We support float and int32 for now. |
57 | // TODO(yangke): add more supports |
58 | template <typename T> |
59 | Status Add(const string& name, const TensorShape& shape, |
60 | const TensorSlice& slice, const T* data); |
61 | Status Finish(); |
62 | |
63 | // Allocate "num_elements" elements in "ss" and save the data in "data" |
64 | // there. |
65 | template <typename T> |
66 | static Status SaveData(const T* data, int64_t num_elements, SavedSlice* ss); |
67 | |
68 | static size_t MaxBytesPerElement(DataType dt); |
69 | |
70 | private: |
71 | static size_t MaxBytesPerElementOrZero(DataType dt); |
72 | |
73 | static constexpr size_t kMaxMessageBytes = 1LL << 31; |
74 | // Filling in the TensorProto in a SavedSlice will add the following |
75 | // header bytes, in addition to the data: |
76 | // - 1 byte: TensorProto tag and wire format |
77 | // - <= 5 bytes: TensorProto length |
78 | // - 1 byte: Repeated *_val tag and wire format |
79 | // - <= 5 bytes: *_val length |
80 | // However, we add 1KB of slack, to be conservative and guard |
81 | // against other additions to the TensorProto. |
82 | static constexpr size_t = 1 << 10; |
83 | |
84 | const string filename_; |
85 | const CreateBuilderFunction create_builder_; |
86 | const string tmpname_; |
87 | |
88 | // A mapping from the tensor names to their index in meta_.saved_slice_meta() |
89 | std::unordered_map<string, int> name_to_index_; |
90 | // The metadata that holds all the saved tensor slices. |
91 | SavedTensorSlices sts_; |
92 | // The data to be written to the builder |
93 | std::map<string, string> data_; |
94 | // Total number of slices written |
95 | int slices_; |
96 | TF_DISALLOW_COPY_AND_ASSIGN(TensorSliceWriter); |
97 | }; |
98 | |
99 | template <typename T> |
100 | Status TensorSliceWriter::Add(const string& name, const TensorShape& shape, |
101 | const TensorSlice& slice, const T* data) { |
102 | // The tensor and the slice have to be compatible |
103 | if (shape.dims() != slice.dims()) { |
104 | return errors::Internal("Incompatible tensor shape and slice: " , "shape = " , |
105 | shape.DebugString(), |
106 | ", slice = " , slice.DebugString()); |
107 | } |
108 | DataType dt = DataTypeToEnum<T>::value; |
109 | // We need to add an entry for "name" if there isn't an entry already. |
110 | int index = gtl::FindWithDefault(name_to_index_, name, -1); |
111 | if (index >= 0) { |
112 | // The same tensor has been registered -- we verify that the shapes and the |
113 | // type agree. |
114 | const SavedSliceMeta& ssm = sts_.meta().tensor(index); |
115 | CHECK_EQ(name, ssm.name()) << ssm.ShortDebugString(); |
116 | TensorShape ssm_shape(ssm.shape()); |
117 | if (!shape.IsSameSize(ssm_shape)) { |
118 | return errors::Internal( |
119 | "Mismatching shapes: existing tensor = " , ssm_shape.DebugString(), |
120 | ", trying to add name " , name, ", shape = " , shape.DebugString()); |
121 | } |
122 | if (dt != ssm.type()) { |
123 | return errors::Internal( |
124 | "Mismatching types: existing type = " , DataTypeString(ssm.type()), |
125 | ", trying to add name " , name, ", type = " , DataTypeString(dt)); |
126 | } |
127 | } else { |
128 | // Insert the new tensor name with the shape information |
129 | index = sts_.meta().tensor_size(); |
130 | name_to_index_.insert(std::make_pair(name, index)); |
131 | SavedSliceMeta* ssm = sts_.mutable_meta()->add_tensor(); |
132 | ssm->set_name(name); |
133 | shape.AsProto(ssm->mutable_shape()); |
134 | ssm->set_type(dt); |
135 | } |
136 | // Now we need to add the slice info the list of slices. |
137 | SavedSliceMeta* ssm = sts_.mutable_meta()->mutable_tensor(index); |
138 | slice.AsProto(ssm->add_slice()); |
139 | |
140 | // Now we need to add the real data. |
141 | { |
142 | SavedTensorSlices sts; |
143 | SavedSlice* ss = sts.mutable_data(); |
144 | ss->set_name(name); |
145 | slice.AsProto(ss->mutable_slice()); |
146 | TensorShape saved_shape(ssm->shape()); |
147 | TensorShape sliced_shape; |
148 | TF_RETURN_IF_ERROR(slice.SliceTensorShape(saved_shape, &sliced_shape)); |
149 | TF_RETURN_IF_ERROR(SaveData(data, sliced_shape.num_elements(), ss)); |
150 | string key = EncodeTensorNameSlice(name, slice); |
151 | // TODO(yangke): consider doing a two-pass thing where the first pass just |
152 | // list the tensor slices we want to save and then another pass to actually |
153 | // set the data. Need to figure out if the interface works well. |
154 | std::pair<string, string> key_value(key, "" ); |
155 | if (!sts.AppendToString(&key_value.second)) { |
156 | return errors::Internal("Error writing Tensor. Possible size overflow." ); |
157 | } |
158 | data_.insert(key_value); |
159 | } |
160 | ++slices_; |
161 | return OkStatus(); |
162 | } |
163 | |
164 | template <typename T> |
165 | Status TensorSliceWriter::SaveData(const T* data, int64_t num_elements, |
166 | SavedSlice* ss) { |
167 | size_t max_bytes_per_element = |
168 | MaxBytesPerElementOrZero(DataTypeToEnum<T>::value); |
169 | if (max_bytes_per_element == 0) { |
170 | return errors::InvalidArgument( |
171 | "Tensor slice serialization not implemented for dtype " , |
172 | DataTypeToEnum<T>::value); |
173 | } |
174 | size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes + |
175 | (max_bytes_per_element * num_elements); |
176 | if (size_bound > kMaxMessageBytes) { |
177 | return errors::InvalidArgument( |
178 | "Tensor slice is too large to serialize (conservative estimate: " , |
179 | size_bound, " bytes)" ); |
180 | } |
181 | Fill(data, num_elements, ss->mutable_data()); |
182 | DCHECK_GE(ss->ByteSize(), 0); |
183 | DCHECK_LE(ss->ByteSize(), size_bound); |
184 | return OkStatus(); |
185 | } |
186 | |
187 | template <> |
188 | Status TensorSliceWriter::SaveData(const tstring* data, int64_t num_elements, |
189 | SavedSlice* ss); |
190 | |
191 | // Create a table builder that will write to "filename" in |
192 | // tensorflow::io::Table format. If successful, return OK |
193 | // and set "*builder" to the allocated builder. Otherwise, return a |
194 | // non-OK status. |
195 | Status CreateTableTensorSliceBuilder(const string& filename, |
196 | TensorSliceWriter::Builder** builder); |
197 | |
198 | } // namespace checkpoint |
199 | |
200 | } // namespace tensorflow |
201 | |
202 | #endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_WRITER_H_ |
203 | |