1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Utilities for saving/restoring tensor slice checkpoints.
17
18#ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
19#define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
20
21#include <string> // for string
22#include "tensorflow/core/framework/tensor.pb.h"
23#include "tensorflow/core/framework/tensor_slice.h"
24#include "tensorflow/core/framework/types.h"
25#include "tensorflow/core/lib/core/status.h" // for Status
26#include "tensorflow/core/platform/protobuf.h"
27
28namespace tensorflow {
29
30namespace checkpoint {
31
// The key under which the metadata is stored in tensor slice checkpoint files.
// Its value is the empty string "", which keeps the metadata entry at the
// beginning of a checkpoint file (presumably because entries are kept in
// sorted key order and "" sorts before every other key).
extern const char kSavedTensorSlicesKey[];
35
// Encodes a tensor name and a tensor slice as an ordered code and returns the
// result as a string.
//
// The encoded layout is:
//   <0>
//   <tensor_name>
//   <rank>
//   <dim-0-start><dim-0-length>
//   <dim-1-start><dim-1-length>
//   ...
// i.e. one start/length pair follows for each dimension of the slice.
string EncodeTensorNameSlice(const string& name,
                             const tensorflow::TensorSlice& slice);
48
// Parses the tensor name and the slice out of `code`, a string encoded as an
// ordered code (presumably one produced by EncodeTensorNameSlice), writing
// them to `*name` and `*slice`.
// NOTE(review): the exact conditions under which a non-OK Status is returned
// live in the implementation, not in this header — confirm before relying on
// them.
Status DecodeTensorNameSlice(const string& code, string* name,
                             tensorflow::TensorSlice* slice);
52
// Extracts the full shape, the slice spec, and the shape of the slice from
// "shape_and_slice".
//   shape:       receives the full tensor shape.
//   slice:       receives the slice spec.
//   shape_slice: receives the shape of the slice itself.
// On non-OK return, the caller must clear the out-arguments before reusing
// them (presumably because they may be left partially filled).
// NOTE(review): the textual format of `shape_and_slice` is defined by the
// implementation, not by this header — confirm there.
Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape,
                          TensorSlice* slice, TensorShape* shape_slice);
58
// Per-type description of how values of T are stored in a TensorProto.
// Only declared here: each supported type gets an explicit specialization
// defining `supported`, `SavedType` and `RepeatedField`, so naming a member of
// SaveTypeTraits<T> for an unsupported T fails to compile (incomplete type).
template <typename T>
struct SaveTypeTraits;

// Returns the number of T values stored in `t`. For complex types this is half
// the number of stored scalars. Declared only; defined per supported type.
template <typename T>
int TensorProtoDataSize(const TensorProto& t);

// Returns a read-only pointer to the stored data of `t`, viewed as an array of
// SaveTypeTraits<T>::SavedType. Declared only; defined per supported type.
template <typename T>
const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
    const TensorProto& t);

// Returns the mutable repeated field of `t` that holds values of type T.
// Declared only; defined per supported type.
template <typename T>
typename SaveTypeTraits<T>::RepeatedField* MutableTensorProtoData(
    TensorProto* t);

// Stores the n values data[0..n) into the appropriate field of `t`.
// NOTE: the specializations provided in this header are declared for
// `const X*` arguments, so T is deduced as a const-qualified type; a pointer to
// non-const deduces an unqualified T, for which no definition is provided here.
template <typename T>
void Fill(T* data, size_t n, TensorProto* t);
75
// Generates, for a C++ type TYPE whose values are saved in the TensorProto
// repeated field FIELD_val (element type FTYPE):
//   - SaveTypeTraits<TYPE>, with SavedType = STYPE and
//     RepeatedField = protobuf::RepeatedField<FTYPE>;
//   - TensorProtoData<TYPE>, which views the read-only field storage as an
//     array of STYPE;
//   - MutableTensorProtoData<TYPE>, which exposes the field for writing.
// Return types are spelled through SaveTypeTraits<TYPE> so they read the same
// way as the primary template declarations they specialize.
#define TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, STYPE)          \
  template <>                                                                \
  struct SaveTypeTraits<TYPE> {                                              \
    static constexpr bool supported = true;                                  \
    using SavedType = STYPE;                                                 \
    using RepeatedField = protobuf::RepeatedField<FTYPE>;                    \
  };                                                                         \
  template <>                                                                \
  inline const SaveTypeTraits<TYPE>::SavedType* TensorProtoData<TYPE>(       \
      const TensorProto& t) {                                                \
    static_assert(SaveTypeTraits<TYPE>::supported,                           \
                  "Specified type " #TYPE " not supported for Restore");     \
    const auto* stored = t.FIELD##_val().data();                             \
    return reinterpret_cast<const SaveTypeTraits<TYPE>::SavedType*>(stored); \
  }                                                                          \
  template <>                                                                \
  inline SaveTypeTraits<TYPE>::RepeatedField* MutableTensorProtoData<TYPE>(  \
      TensorProto* t) {                                                      \
    static_assert(SaveTypeTraits<TYPE>::supported,                           \
                  "Specified type " #TYPE " not supported for Save");        \
    auto* stored = t->mutable_##FIELD##_val();                               \
    return reinterpret_cast<SaveTypeTraits<TYPE>::RepeatedField*>(stored);   \
  }
97
// Full set of save/restore helpers for a type TYPE saved element-for-element
// into the TensorProto field FIELD_val, whose element type FTYPE is also used
// as SavedType. On top of TENSOR_PROTO_EXTRACT_TYPE_HELPER it defines:
//   - TensorProtoDataSize<TYPE>: the number of entries in FIELD_val.
//   - Fill for `const TYPE*`: an explicit specialization of the Fill template
//     (T is deduced as `const TYPE`). It builds a fresh RepeatedField<FTYPE>
//     from data[0..n), converting each element to FTYPE where the types differ.
//     It then swaps that field into FIELD_val, replacing any previous contents.
#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE)                   \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE)           \
  template <>                                                           \
  inline int TensorProtoDataSize<TYPE>(const TensorProto& t) {          \
    return t.FIELD##_val_size();                                        \
  }                                                                     \
  template <>                                                           \
  inline void Fill(const TYPE* data, size_t n, TensorProto* t) {        \
    typename protobuf::RepeatedField<FTYPE> copy(data, data + n);       \
    t->mutable_##FIELD##_val()->Swap(&copy);                            \
  }
109
// Complex types need special treatment because TensorProto has no native
// complex field. Each TYPE value is stored as two consecutive FTYPE scalars in
// FIELD_val (presumably the real part followed by the imaginary part, matching
// the in-memory layout of TYPE — confirm). Consequently:
//   - SavedType is TYPE itself; TensorProtoData reinterprets the FTYPE array
//     as an array of TYPE.
//   - TensorProtoDataSize<TYPE> is half the number of stored scalars.
//   - Fill (explicit specialization, T deduced as `const TYPE`) views data as
//     2 * n FTYPE scalars, copies them into a fresh RepeatedField<FTYPE> and
//     swaps that field into FIELD_val, replacing any previous contents.
#define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE)           \
  TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE)            \
  template <>                                                           \
  inline int TensorProtoDataSize<TYPE>(const TensorProto& t) {          \
    return t.FIELD##_val_size() / 2;                                    \
  }                                                                     \
  template <>                                                           \
  inline void Fill(const TYPE* data, size_t n, TensorProto* t) {        \
    const FTYPE* sub = reinterpret_cast<const FTYPE*>(data);            \
    typename protobuf::RepeatedField<FTYPE> copy(sub, sub + 2 * n);     \
    t->mutable_##FIELD##_val()->Swap(&copy);                            \
  }
123
124TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool);
125TENSOR_PROTO_EXTRACT_TYPE(float, float, float);
126TENSOR_PROTO_EXTRACT_TYPE(double, double, double);
127TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float);
128TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double);
129TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32);
130TENSOR_PROTO_EXTRACT_TYPE(uint32, uint32, uint32);
131TENSOR_PROTO_EXTRACT_TYPE(int64_t, int64, protobuf_int64);
132TENSOR_PROTO_EXTRACT_TYPE(uint64, uint64, protobuf_uint64);
133TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32);
134TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32);
135TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
136TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32);
137TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32);
138TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32);
139TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32);
140
141#undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX
142#undef TENSOR_PROTO_EXTRACT_TYPE_HELPER
143#undef TENSOR_PROTO_EXTRACT_TYPE
144
145// Custom implementation for qint32, based on the one for int32.
146
147template <>
148struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};
149
150template <>
151inline int TensorProtoDataSize<qint32>(const TensorProto& t) {
152 return t.int_val_size();
153}
154
155template <>
156inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
157 static_assert(SaveTypeTraits<qint32>::supported,
158 "Specified type qint32 not supported for Restore");
159 return reinterpret_cast<const int32*>(t.int_val().data());
160}
161
162inline void Fill(const qint32* data, size_t n, TensorProto* t) {
163 const int32* p = reinterpret_cast<const int32*>(data);
164 typename protobuf::RepeatedField<int32> copy(p, p + n);
165 t->mutable_int_val()->Swap(&copy);
166}
167
168// Custom implementation for Eigen::half.
169
170template <>
171struct SaveTypeTraits<Eigen::half> {
172 static constexpr bool supported = true;
173 typedef int SavedType;
174 typedef protobuf::RepeatedField<int32> RepeatedField;
175};
176
177template <>
178inline int TensorProtoDataSize<Eigen::half>(const TensorProto& t) {
179 return t.half_val_size();
180}
181
182template <>
183inline const int* TensorProtoData<Eigen::half>(const TensorProto& t) {
184 return t.half_val().data();
185}
186
187template <>
188inline protobuf::RepeatedField<int32>* MutableTensorProtoData<Eigen::half>(
189 TensorProto* t) {
190 return t->mutable_half_val();
191}
192
193template <>
194inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) {
195 typename protobuf::RepeatedField<int32>* val = t->mutable_half_val();
196 val->Resize(n, 0);
197 for (size_t i = 0; i < n; ++i) {
198 val->Set(i, Eigen::numext::bit_cast<uint16>(data[i]));
199 }
200}
201
202// Custom implementation for string.
203
204template <>
205struct SaveTypeTraits<tstring> {
206 static constexpr bool supported = true;
207 typedef const string* SavedType;
208 typedef protobuf::RepeatedPtrField<string> RepeatedField;
209};
210
211template <>
212inline int TensorProtoDataSize<tstring>(const TensorProto& t) {
213 return t.string_val_size();
214}
215
216template <>
217inline const string* const* TensorProtoData<tstring>(const TensorProto& t) {
218 static_assert(SaveTypeTraits<tstring>::supported,
219 "Specified type tstring not supported for Restore");
220 return t.string_val().data();
221}
222
223template <>
224inline protobuf::RepeatedPtrField<string>* MutableTensorProtoData<tstring>(
225 TensorProto* t) {
226 static_assert(SaveTypeTraits<tstring>::supported,
227 "Specified type tstring not supported for Save");
228 return t->mutable_string_val();
229}
230
231template <>
232inline void Fill(const tstring* data, size_t n, TensorProto* t) {
233 typename protobuf::RepeatedPtrField<string> copy(data, data + n);
234 t->mutable_string_val()->Swap(&copy);
235}
236
237} // namespace checkpoint
238
239} // namespace tensorflow
240
241#endif // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
242