1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
17#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
18
19#include <iostream>
20#include <type_traits>
21#include <utility>
22#include <vector>
23
24#include "tensorflow/core/framework/tensor.h"
25#include "tensorflow/core/framework/type_index.h"
26#include "tensorflow/core/framework/variant_tensor_data.h"
27#include "tensorflow/core/lib/strings/strcat.h"
28#include "tensorflow/core/platform/abi.h"
29#include "tensorflow/core/platform/protobuf.h"
30
31namespace tensorflow {
32
33// Type used for tag-dispatch of the Encode/Decode Variant implementations. This
34// template can determine whether the first type parameter `T` is one of the
35// following:
36//
37// * A POD type (TypeResolver<T, true>)
38// * A tensorflow::Tensor (TypeResolver<T, false, true>)
39// * A protocol buffer (TypeResolver<T, false, false, true>)
40// * None of the above (TypeResolver<T, false, false, false>)
41//
42template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value,
43 bool = std::is_same<typename std::decay<T>::type,
44 ::tensorflow::Tensor>::value,
45 bool = std::is_base_of<protobuf::MessageLite,
46 typename std::decay<T>::type>::value>
47struct TypeResolver {};
48
49// Specialization for POD type
50template <typename T>
51void EncodeVariantImpl(const T& value, TypeResolver<T, true /* is_pod */>,
52 VariantTensorData* data) {
53 data->set_metadata(value);
54}
55
56// Specialization for tensorflow::Tensor
57template <typename T>
58void EncodeVariantImpl(const T& value,
59 TypeResolver<T, false /* is_pod */, true /* Tensor */>,
60 VariantTensorData* data) {
61 data->tensors_.clear();
62 data->tensors_.push_back(value);
63}
64
65// Specialization for protobuf
66template <typename T>
67void EncodeVariantImpl(const T& value,
68 TypeResolver<T, false /* is_pod */, false /* Tensor */,
69 true /* protobuf */>,
70 VariantTensorData* data) {
71 value.SerializeToString(&data->metadata_);
72}
73
74// Specialization for other types
75template <typename T>
76void EncodeVariantImpl(const T& value,
77 TypeResolver<T, false /* is_pod */, false /* Tensor */,
78 false /* protobuf */>,
79 VariantTensorData* data) {
80 value.Encode(data);
81}
82
83// Specialization for POD type
84template <typename T>
85bool DecodeVariantImpl(VariantTensorData data,
86 TypeResolver<T, true /* is_pod */, false /* Tensor */,
87 false /* protobuf */>,
88 T* value) {
89 return data.get_metadata(value);
90}
91
92// Specialization for tensorflow::Tensor
93template <typename T>
94bool DecodeVariantImpl(VariantTensorData data,
95 TypeResolver<T, false /* is_pod */, true /* Tensor */,
96 false /* protobuf */>,
97 T* value) {
98 *value = data.tensors(0);
99 return true;
100}
101
102// Specialization for protobuf
103template <typename T>
104bool DecodeVariantImpl(VariantTensorData data,
105 TypeResolver<T, false /* is_pod */, false /* Tensor */,
106 true /* protobuf */>,
107 T* value) {
108 std::string metadata;
109 data.get_metadata(&metadata);
110 return value->ParseFromString(std::move(metadata));
111}
112
113// Specialization for other types
114template <typename T>
115bool DecodeVariantImpl(VariantTensorData data,
116 TypeResolver<T, false /* is_pod */, false /* Tensor */,
117 false /* protobuf */>,
118 T* value) {
119 return value->Decode(std::move(data));
120}
121
122template <typename C, typename = void>
123struct has_type_name : std::false_type {};
124
125template <typename C>
126struct has_type_name<
127 C, typename std::enable_if<std::is_same<
128 decltype(std::declval<C>().TypeName()), string>::value>::type>
129 : std::true_type {};
130
131template <typename T, bool = has_type_name<typename std::decay<T>::type>::value,
132 bool = std::is_same<typename std::decay<T>::type,
133 ::tensorflow::Tensor>::value,
134 bool = std::is_base_of<protobuf::MessageLite,
135 typename std::decay<T>::type>::value>
136struct TypeNameResolver {};
137
138template <typename T>
139std::string TypeNameVariantImpl(const T& value,
140 TypeNameResolver<T, true /* has_type_name */>) {
141 return value.TypeName();
142}
143
144template <typename T>
145std::string TypeNameVariantImpl(
146 const T& value,
147 TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) {
148 return "tensorflow::Tensor";
149}
150
151template <typename T>
152std::string TypeNameVariantImpl(
153 const T& value, TypeNameResolver<T, false /* has_type_name */,
154 false /* Tensor */, true /* protobuf */>) {
155 return value.GetTypeName();
156}
157
158template <typename T>
159std::string TypeNameVariantImpl(
160 const T& value,
161 TypeNameResolver<T, false /* has_type_name */, false /* Tensor */,
162 false /* protobuf */>) {
163 return port::MaybeAbiDemangle(TypeIndex::Make<T>().name());
164}
165
166template <typename T>
167std::string TypeNameVariant(const T& value) {
168 return TypeNameVariantImpl(value, TypeNameResolver<T>());
169}
170
171template <typename C, typename = void>
172struct has_debug_string : std::false_type {};
173
174template <typename C>
175struct has_debug_string<
176 C, typename std::enable_if<std::is_same<
177 decltype(std::declval<C>().DebugString()), string>::value>::type>
178 : std::true_type {};
179
180template <typename C, typename = void>
181struct can_strcat : std::false_type {};
182
183template <typename C>
184struct can_strcat<
185 C, typename std::enable_if<std::is_same<
186 decltype(strings::StrCat(std::declval<C>())), string>::value>::type>
187 : std::true_type {};
188
189template <typename T,
190 bool = has_debug_string<typename std::decay<T>::type>::value,
191 bool = can_strcat<typename std::decay<T>::type>::value>
192struct DebugStringResolver {};
193
194// TODO(ebrevdo): Expand DebugStringResolver to return TypeString if
195// there is no StrCat<T>() constructor.
196template <typename T>
197std::string DebugStringVariantImpl(
198 const T& value, DebugStringResolver<T, true /* has_debug_string */>) {
199 return value.DebugString();
200}
201
202template <typename T>
203std::string DebugStringVariantImpl(
204 const T& value, DebugStringResolver<T, false /* has_debug_string */,
205 true /* can_strcat */>) {
206 return strings::StrCat(value);
207}
208
209template <typename T>
210std::string DebugStringVariantImpl(
211 const T& value, DebugStringResolver<T, false /* has_debug_string */,
212 false /* can_strcat */>) {
213 return "?";
214}
215
216template <typename T>
217std::string DebugStringVariant(const T& value) {
218 return DebugStringVariantImpl(value, DebugStringResolver<T>());
219}
220
221template <typename T>
222void EncodeVariant(const T& value, VariantTensorData* data) {
223 EncodeVariantImpl(value, TypeResolver<T>(), data);
224 data->set_type_name(TypeNameVariant(value));
225}
226
227template <typename T>
228bool DecodeVariant(VariantTensorData* data, T* value) {
229 return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
230}
231
232template <typename T>
233void EncodeVariant(const T& value, std::string* buf) {
234 VariantTensorData data;
235 EncodeVariantImpl(value, TypeResolver<T>(), &data);
236 data.set_type_name(TypeNameVariant(value));
237 DCHECK(buf != nullptr);
238 data.SerializeToString(buf);
239}
240
241template <typename T>
242bool DecodeVariant(std::string* buf, T* value) {
243 VariantTensorData data;
244 if (!data.ParseFromString(*buf)) return false;
245 if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
246 return false;
247 }
248 return true;
249}
250
// Explicit specializations for VariantTensorDataProto.
//
// They replace the generic templates above for this type. Only declarations
// appear in this header; the definitions are provided out of line
// (presumably in the matching .cc file — confirm).
template <>
std::string TypeNameVariant(const VariantTensorDataProto& value);

// Encodes a VariantTensorDataProto into a VariantTensorData.
template <>
void EncodeVariant(const VariantTensorDataProto& value,
                   VariantTensorData* data);

// Decodes a VariantTensorData into a VariantTensorDataProto; returns a
// success flag.
template <>
bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);

// Encodes a VariantTensorDataProto into the string `*buf`.
template <>
void EncodeVariant(const VariantTensorDataProto& value, std::string* buf);

// Decodes a VariantTensorDataProto from the string `*buf`; returns a success
// flag.
template <>
bool DecodeVariant(std::string* buf, VariantTensorDataProto* value);
267
// Encodes an array of Variant objects into the given StringListEncoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
// Ownership of the encoder `e` is transferred to this function.
void EncodeVariantList(const Variant* variant_array, int64_t n,
                       std::unique_ptr<port::StringListEncoder> e);
272
// Decodes an array of Variant objects from the given StringListDecoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
// Ownership of the decoder `d` is transferred to this function. Returns a
// success flag (presumably false on any decode failure — confirm in the
// definition).
bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
                       Variant* variant_array, int64_t n);
277
278} // end namespace tensorflow
279
280#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
281