1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ |
18 | |
19 | #include <iostream> |
20 | #include <type_traits> |
21 | #include <utility> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/framework/type_index.h" |
26 | #include "tensorflow/core/framework/variant_tensor_data.h" |
27 | #include "tensorflow/core/lib/strings/strcat.h" |
28 | #include "tensorflow/core/platform/abi.h" |
29 | #include "tensorflow/core/platform/protobuf.h" |
30 | |
31 | namespace tensorflow { |
32 | |
33 | // Type used for tag-dispatch of the Encode/Decode Variant implementations. This |
34 | // template can determine whether the first type parameter `T` is one of the |
35 | // following: |
36 | // |
37 | // * A POD type (TypeResolver<T, true>) |
38 | // * A tensorflow::Tensor (TypeResolver<T, false, true>) |
39 | // * A protocol buffer (TypeResolver<T, false, false, true>) |
40 | // * None of the above (TypeResolver<T, false, false, false>) |
41 | // |
42 | template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value, |
43 | bool = std::is_same<typename std::decay<T>::type, |
44 | ::tensorflow::Tensor>::value, |
45 | bool = std::is_base_of<protobuf::MessageLite, |
46 | typename std::decay<T>::type>::value> |
47 | struct TypeResolver {}; |
48 | |
49 | // Specialization for POD type |
50 | template <typename T> |
51 | void EncodeVariantImpl(const T& value, TypeResolver<T, true /* is_pod */>, |
52 | VariantTensorData* data) { |
53 | data->set_metadata(value); |
54 | } |
55 | |
56 | // Specialization for tensorflow::Tensor |
57 | template <typename T> |
58 | void EncodeVariantImpl(const T& value, |
59 | TypeResolver<T, false /* is_pod */, true /* Tensor */>, |
60 | VariantTensorData* data) { |
61 | data->tensors_.clear(); |
62 | data->tensors_.push_back(value); |
63 | } |
64 | |
65 | // Specialization for protobuf |
66 | template <typename T> |
67 | void EncodeVariantImpl(const T& value, |
68 | TypeResolver<T, false /* is_pod */, false /* Tensor */, |
69 | true /* protobuf */>, |
70 | VariantTensorData* data) { |
71 | value.SerializeToString(&data->metadata_); |
72 | } |
73 | |
74 | // Specialization for other types |
75 | template <typename T> |
76 | void EncodeVariantImpl(const T& value, |
77 | TypeResolver<T, false /* is_pod */, false /* Tensor */, |
78 | false /* protobuf */>, |
79 | VariantTensorData* data) { |
80 | value.Encode(data); |
81 | } |
82 | |
83 | // Specialization for POD type |
84 | template <typename T> |
85 | bool DecodeVariantImpl(VariantTensorData data, |
86 | TypeResolver<T, true /* is_pod */, false /* Tensor */, |
87 | false /* protobuf */>, |
88 | T* value) { |
89 | return data.get_metadata(value); |
90 | } |
91 | |
92 | // Specialization for tensorflow::Tensor |
93 | template <typename T> |
94 | bool DecodeVariantImpl(VariantTensorData data, |
95 | TypeResolver<T, false /* is_pod */, true /* Tensor */, |
96 | false /* protobuf */>, |
97 | T* value) { |
98 | *value = data.tensors(0); |
99 | return true; |
100 | } |
101 | |
102 | // Specialization for protobuf |
103 | template <typename T> |
104 | bool DecodeVariantImpl(VariantTensorData data, |
105 | TypeResolver<T, false /* is_pod */, false /* Tensor */, |
106 | true /* protobuf */>, |
107 | T* value) { |
108 | std::string metadata; |
109 | data.get_metadata(&metadata); |
110 | return value->ParseFromString(std::move(metadata)); |
111 | } |
112 | |
113 | // Specialization for other types |
114 | template <typename T> |
115 | bool DecodeVariantImpl(VariantTensorData data, |
116 | TypeResolver<T, false /* is_pod */, false /* Tensor */, |
117 | false /* protobuf */>, |
118 | T* value) { |
119 | return value->Decode(std::move(data)); |
120 | } |
121 | |
122 | template <typename C, typename = void> |
123 | struct has_type_name : std::false_type {}; |
124 | |
125 | template <typename C> |
126 | struct has_type_name< |
127 | C, typename std::enable_if<std::is_same< |
128 | decltype(std::declval<C>().TypeName()), string>::value>::type> |
129 | : std::true_type {}; |
130 | |
131 | template <typename T, bool = has_type_name<typename std::decay<T>::type>::value, |
132 | bool = std::is_same<typename std::decay<T>::type, |
133 | ::tensorflow::Tensor>::value, |
134 | bool = std::is_base_of<protobuf::MessageLite, |
135 | typename std::decay<T>::type>::value> |
136 | struct TypeNameResolver {}; |
137 | |
138 | template <typename T> |
139 | std::string TypeNameVariantImpl(const T& value, |
140 | TypeNameResolver<T, true /* has_type_name */>) { |
141 | return value.TypeName(); |
142 | } |
143 | |
144 | template <typename T> |
145 | std::string TypeNameVariantImpl( |
146 | const T& value, |
147 | TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) { |
148 | return "tensorflow::Tensor" ; |
149 | } |
150 | |
151 | template <typename T> |
152 | std::string TypeNameVariantImpl( |
153 | const T& value, TypeNameResolver<T, false /* has_type_name */, |
154 | false /* Tensor */, true /* protobuf */>) { |
155 | return value.GetTypeName(); |
156 | } |
157 | |
158 | template <typename T> |
159 | std::string TypeNameVariantImpl( |
160 | const T& value, |
161 | TypeNameResolver<T, false /* has_type_name */, false /* Tensor */, |
162 | false /* protobuf */>) { |
163 | return port::MaybeAbiDemangle(TypeIndex::Make<T>().name()); |
164 | } |
165 | |
166 | template <typename T> |
167 | std::string TypeNameVariant(const T& value) { |
168 | return TypeNameVariantImpl(value, TypeNameResolver<T>()); |
169 | } |
170 | |
171 | template <typename C, typename = void> |
172 | struct has_debug_string : std::false_type {}; |
173 | |
174 | template <typename C> |
175 | struct has_debug_string< |
176 | C, typename std::enable_if<std::is_same< |
177 | decltype(std::declval<C>().DebugString()), string>::value>::type> |
178 | : std::true_type {}; |
179 | |
180 | template <typename C, typename = void> |
181 | struct can_strcat : std::false_type {}; |
182 | |
183 | template <typename C> |
184 | struct can_strcat< |
185 | C, typename std::enable_if<std::is_same< |
186 | decltype(strings::StrCat(std::declval<C>())), string>::value>::type> |
187 | : std::true_type {}; |
188 | |
189 | template <typename T, |
190 | bool = has_debug_string<typename std::decay<T>::type>::value, |
191 | bool = can_strcat<typename std::decay<T>::type>::value> |
192 | struct DebugStringResolver {}; |
193 | |
194 | // TODO(ebrevdo): Expand DebugStringResolver to return TypeString if |
195 | // there is no StrCat<T>() constructor. |
196 | template <typename T> |
197 | std::string DebugStringVariantImpl( |
198 | const T& value, DebugStringResolver<T, true /* has_debug_string */>) { |
199 | return value.DebugString(); |
200 | } |
201 | |
202 | template <typename T> |
203 | std::string DebugStringVariantImpl( |
204 | const T& value, DebugStringResolver<T, false /* has_debug_string */, |
205 | true /* can_strcat */>) { |
206 | return strings::StrCat(value); |
207 | } |
208 | |
209 | template <typename T> |
210 | std::string DebugStringVariantImpl( |
211 | const T& value, DebugStringResolver<T, false /* has_debug_string */, |
212 | false /* can_strcat */>) { |
213 | return "?" ; |
214 | } |
215 | |
216 | template <typename T> |
217 | std::string DebugStringVariant(const T& value) { |
218 | return DebugStringVariantImpl(value, DebugStringResolver<T>()); |
219 | } |
220 | |
221 | template <typename T> |
222 | void EncodeVariant(const T& value, VariantTensorData* data) { |
223 | EncodeVariantImpl(value, TypeResolver<T>(), data); |
224 | data->set_type_name(TypeNameVariant(value)); |
225 | } |
226 | |
227 | template <typename T> |
228 | bool DecodeVariant(VariantTensorData* data, T* value) { |
229 | return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value); |
230 | } |
231 | |
232 | template <typename T> |
233 | void EncodeVariant(const T& value, std::string* buf) { |
234 | VariantTensorData data; |
235 | EncodeVariantImpl(value, TypeResolver<T>(), &data); |
236 | data.set_type_name(TypeNameVariant(value)); |
237 | DCHECK(buf != nullptr); |
238 | data.SerializeToString(buf); |
239 | } |
240 | |
241 | template <typename T> |
242 | bool DecodeVariant(std::string* buf, T* value) { |
243 | VariantTensorData data; |
244 | if (!data.ParseFromString(*buf)) return false; |
245 | if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) { |
246 | return false; |
247 | } |
248 | return true; |
249 | } |
250 | |
251 | // Specializations for VariantTensorDataProto |
252 | template <> |
253 | std::string TypeNameVariant(const VariantTensorDataProto& value); |
254 | |
255 | template <> |
256 | void EncodeVariant(const VariantTensorDataProto& value, |
257 | VariantTensorData* data); |
258 | |
259 | template <> |
260 | bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value); |
261 | |
262 | template <> |
263 | void EncodeVariant(const VariantTensorDataProto& value, std::string* buf); |
264 | |
265 | template <> |
266 | bool DecodeVariant(std::string* buf, VariantTensorDataProto* value); |
267 | |
// Encodes an array of Variant objects into the given StringListEncoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
// The encoder is taken by std::unique_ptr, so ownership of `e` passes to
// this function.
void EncodeVariantList(const Variant* variant_array, int64_t n,
                       std::unique_ptr<port::StringListEncoder> e);

// Decodes an array of Variant objects from the given StringListDecoder.
// `variant_array` is assumed to point to an array of `n` Variant objects;
// presumably these receive the decoded values (NOTE(review): confirm).
// Ownership of `d` passes to this function. The bool result presumably
// reports success (NOTE(review): confirm against the definition).
bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
                       Variant* variant_array, int64_t n);
277 | |
278 | } // end namespace tensorflow |
279 | |
280 | #endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ |
281 | |