1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // Utilities for saving/restoring tensor slice checkpoints. |
17 | |
18 | #ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ |
19 | #define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ |
20 | |
21 | #include <string> // for string |
22 | #include "tensorflow/core/framework/tensor.pb.h" |
23 | #include "tensorflow/core/framework/tensor_slice.h" |
24 | #include "tensorflow/core/framework/types.h" |
25 | #include "tensorflow/core/lib/core/status.h" // for Status |
26 | #include "tensorflow/core/platform/protobuf.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | namespace checkpoint { |
31 | |
// Key under which the metadata entry is stored in a tensor slice checkpoint
// file. The key is the empty string "", which orders before every other key,
// so the metadata entry always comes first in the file.
extern const char kSavedTensorSlicesKey[];
35 | |
36 | // Encode a tensor name + a tensor slice into an ordered code and outputs it as |
37 | // a string. |
38 | // The format is |
39 | // <0> |
40 | // <tensor_name> |
41 | // <rank> |
42 | // <dim-0-start><dim-0-length> |
43 | // <dim-1-start><dim-1-length> |
44 | // ... |
45 | |
46 | string EncodeTensorNameSlice(const string& name, |
47 | const tensorflow::TensorSlice& slice); |
48 | |
49 | // Parse out the name and the slice from string encoded as an ordered code. |
50 | Status DecodeTensorNameSlice(const string& code, string* name, |
51 | tensorflow::TensorSlice* slice); |
52 | |
53 | // Extracts the full shape, slice spec, and shape of the slice from |
54 | // "shape_and_slice". On non-OK return, caller must clear the out-arguments |
55 | // before reusing. |
56 | Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape, |
57 | TensorSlice* slice, TensorShape* shape_slice); |
58 | |
59 | template <typename T> |
60 | struct SaveTypeTraits; |
61 | |
62 | template <typename T> |
63 | int TensorProtoDataSize(const TensorProto& t); |
64 | |
65 | template <typename T> |
66 | const typename SaveTypeTraits<T>::SavedType* TensorProtoData( |
67 | const TensorProto& t); |
68 | |
69 | template <typename T> |
70 | typename SaveTypeTraits<T>::RepeatedField* MutableTensorProtoData( |
71 | TensorProto* t); |
72 | |
73 | template <typename T> |
74 | void Fill(T* data, size_t n, TensorProto* t); |
75 | |
// Generates, for the C++ element type TYPE, the SaveTypeTraits<TYPE>
// specialization together with the TensorProtoData<TYPE> (read) and
// MutableTensorProtoData<TYPE> (write) accessors.
//   FIELD - proto field prefix; values live in `FIELD##_val`.
//   FTYPE - element type of that repeated proto field.
//   STYPE - element type exposed to readers (SavedType). It equals FTYPE for
//           plain types and TYPE when the stored FTYPE values are
//           reinterpreted as TYPE (the complex case).
// The macro name was missing (`#define (TYPE, ...)` is ill-formed), yet it is
// invoked as TENSOR_PROTO_EXTRACT_TYPE_HELPER by the macros below.
#define TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, STYPE)        \
  template <>                                                            \
  struct SaveTypeTraits<TYPE> {                                          \
    static constexpr bool supported = true;                              \
    typedef STYPE SavedType;                                             \
    typedef protobuf::RepeatedField<FTYPE> RepeatedField;                \
  };                                                                     \
  template <>                                                            \
  inline const STYPE* TensorProtoData<TYPE>(const TensorProto& t) {      \
    static_assert(SaveTypeTraits<TYPE>::supported,                       \
                  "Specified type " #TYPE " not supported for Restore"); \
    return reinterpret_cast<const STYPE*>(t.FIELD##_val().data());       \
  }                                                                      \
  template <>                                                            \
  inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>(   \
      TensorProto* t) {                                                  \
    static_assert(SaveTypeTraits<TYPE>::supported,                       \
                  "Specified type " #TYPE " not supported for Save");    \
    return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>(            \
        t->mutable_##FIELD##_val());                                     \
  }
97 | |
// Generates the full set of accessors for a type TYPE whose values are stored
// one-per-element in the repeated proto field `FIELD##_val` of element type
// FTYPE. This covers the traits and read/write accessors (via
// TENSOR_PROTO_EXTRACT_TYPE_HELPER), TensorProtoDataSize<TYPE>, and a Fill
// specialization. Fill converts each TYPE value to FTYPE and replaces the
// field's previous contents.
// Fixes: the macro name was missing (it is invoked as
// TENSOR_PROTO_EXTRACT_TYPE below), and `Swap(&copy)` had been mis-encoded.
#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE)                      \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE)              \
  template <>                                                            \
  inline int TensorProtoDataSize<TYPE>(const TensorProto& t) {           \
    return t.FIELD##_val_size();                                         \
  }                                                                      \
  template <>                                                            \
  inline void Fill(const TYPE* data, size_t n, TensorProto* t) {         \
    /* Build the new contents off to the side, then swap them in. */     \
    typename protobuf::RepeatedField<FTYPE> copy(data, data + n);        \
    t->mutable_##FIELD##_val()->Swap(&copy);                             \
  }
109 | |
// Complex types need special treatment because proto has no native complex
// number type. Each TYPE value is stored as two consecutive FTYPE entries in
// `FIELD##_val`, so the element count is half the field size. Fill
// reinterprets the input as 2 * n FTYPE values. This relies on TYPE being
// laid out as two contiguous FTYPE components, as std::complex guarantees.
// NOTE(review): assumes complex64/complex128 are std::complex aliases; confirm.
// Fixes: the macro name was missing (it is invoked as
// TENSOR_PROTO_EXTRACT_TYPE_COMPLEX below), and `Swap(&copy)` had been
// mis-encoded.
#define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE)              \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE)               \
  template <>                                                            \
  inline int TensorProtoDataSize<TYPE>(const TensorProto& t) {           \
    return t.FIELD##_val_size() / 2;                                     \
  }                                                                      \
  template <>                                                            \
  inline void Fill(const TYPE* data, size_t n, TensorProto* t) {         \
    const FTYPE* sub = reinterpret_cast<const FTYPE*>(data);             \
    typename protobuf::RepeatedField<FTYPE> copy(sub, sub + 2 * n);      \
    t->mutable_##FIELD##_val()->Swap(&copy);                             \
  }
123 | |
124 | TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool); |
125 | TENSOR_PROTO_EXTRACT_TYPE(float, float, float); |
126 | TENSOR_PROTO_EXTRACT_TYPE(double, double, double); |
127 | TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float); |
128 | TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double); |
129 | TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32); |
130 | TENSOR_PROTO_EXTRACT_TYPE(uint32, uint32, uint32); |
131 | TENSOR_PROTO_EXTRACT_TYPE(int64_t, int64, protobuf_int64); |
132 | TENSOR_PROTO_EXTRACT_TYPE(uint64, uint64, protobuf_uint64); |
133 | TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32); |
134 | TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32); |
135 | TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32); |
136 | TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32); |
137 | TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32); |
138 | TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32); |
139 | TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32); |
140 | |
141 | #undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX |
142 | #undef TENSOR_PROTO_EXTRACT_TYPE_HELPER |
143 | #undef TENSOR_PROTO_EXTRACT_TYPE |
144 | |
145 | // Custom implementation for qint32, based on the one for int32. |
146 | |
147 | template <> |
148 | struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {}; |
149 | |
150 | template <> |
151 | inline int TensorProtoDataSize<qint32>(const TensorProto& t) { |
152 | return t.int_val_size(); |
153 | } |
154 | |
155 | template <> |
156 | inline const int32* TensorProtoData<qint32>(const TensorProto& t) { |
157 | static_assert(SaveTypeTraits<qint32>::supported, |
158 | "Specified type qint32 not supported for Restore" ); |
159 | return reinterpret_cast<const int32*>(t.int_val().data()); |
160 | } |
161 | |
162 | inline void Fill(const qint32* data, size_t n, TensorProto* t) { |
163 | const int32* p = reinterpret_cast<const int32*>(data); |
164 | typename protobuf::RepeatedField<int32> copy(p, p + n); |
165 | t->mutable_int_val()->Swap(©); |
166 | } |
167 | |
168 | // Custom implementation for Eigen::half. |
169 | |
170 | template <> |
171 | struct SaveTypeTraits<Eigen::half> { |
172 | static constexpr bool supported = true; |
173 | typedef int SavedType; |
174 | typedef protobuf::RepeatedField<int32> RepeatedField; |
175 | }; |
176 | |
177 | template <> |
178 | inline int TensorProtoDataSize<Eigen::half>(const TensorProto& t) { |
179 | return t.half_val_size(); |
180 | } |
181 | |
182 | template <> |
183 | inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) { |
184 | return t.half_val().data(); |
185 | } |
186 | |
187 | template <> |
188 | inline protobuf::RepeatedField<int32>* MutableTensorProtoData<Eigen::half>( |
189 | TensorProto* t) { |
190 | return t->mutable_half_val(); |
191 | } |
192 | |
193 | template <> |
194 | inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) { |
195 | typename protobuf::RepeatedField<int32>* val = t->mutable_half_val(); |
196 | val->Resize(n, 0); |
197 | for (size_t i = 0; i < n; ++i) { |
198 | val->Set(i, Eigen::numext::bit_cast<uint16>(data[i])); |
199 | } |
200 | } |
201 | |
202 | // Custom implementation for string. |
203 | |
204 | template <> |
205 | struct SaveTypeTraits<tstring> { |
206 | static constexpr bool supported = true; |
207 | typedef const string* SavedType; |
208 | typedef protobuf::RepeatedPtrField<string> RepeatedField; |
209 | }; |
210 | |
211 | template <> |
212 | inline int TensorProtoDataSize<tstring>(const TensorProto& t) { |
213 | return t.string_val_size(); |
214 | } |
215 | |
216 | template <> |
217 | inline const string* const* TensorProtoData<tstring>(const TensorProto& t) { |
218 | static_assert(SaveTypeTraits<tstring>::supported, |
219 | "Specified type tstring not supported for Restore" ); |
220 | return t.string_val().data(); |
221 | } |
222 | |
223 | template <> |
224 | inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<tstring>( |
225 | TensorProto* t) { |
226 | static_assert(SaveTypeTraits<tstring>::supported, |
227 | "Specified type tstring not supported for Save" ); |
228 | return t->mutable_string_val(); |
229 | } |
230 | |
231 | template <> |
232 | inline void Fill(const tstring* data, size_t n, TensorProto* t) { |
233 | typename protobuf::RepeatedPtrField<string> copy(data, data + n); |
234 | t->mutable_string_val()->Swap(©); |
235 | } |
236 | |
237 | } // namespace checkpoint |
238 | |
239 | } // namespace tensorflow |
240 | |
241 | #endif // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ |
242 | |