1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ |
18 | |
19 | #include <algorithm> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/framework/tensor.h" |
23 | #include "tensorflow/core/framework/tensor.pb.h" |
24 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
25 | #include "tensorflow/core/framework/type_traits.h" |
26 | #include "tensorflow/core/platform/protobuf.h" |
27 | #include "tensorflow/core/platform/types.h" |
28 | |
29 | namespace tensorflow { |
30 | namespace tensor { |
31 | |
32 | // DeepCopy returns a tensor whose contents are a deep copy of the |
33 | // contents of 'other'. This function is intended only for |
34 | // convenience, not speed. |
35 | // |
36 | // REQUIRES: 'other' must point to data stored in CPU memory. |
37 | // REQUIRES: 'other' must be a Tensor of a copy-able type if |
38 | // 'other' is not appropriately memory-aligned. |
39 | Tensor DeepCopy(const Tensor& other); |
40 | |
41 | // Deep copies input to output. This function is similar to above, but assumes |
42 | // that the memory for the output has already been allocated. |
43 | void DeepCopy(const Tensor& input, Tensor* output); |
44 | |
45 | // Concatenates 'tensors' into a single tensor, along their 0th dimension. |
46 | // |
47 | // REQUIRES: All members of 'tensors' must have the same data type parameter. |
48 | // REQUIRES: Each member of 'tensors' must have at least one dimension. |
49 | // REQUIRES: Each member of 'tensors' must point to data stored in CPU memory. |
50 | // REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it |
51 | // is not appropriately memory-aligned. |
52 | Status Concat(const gtl::ArraySlice<Tensor>& tensors, |
53 | Tensor* result) TF_MUST_USE_RESULT; |
54 | |
55 | // Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th |
56 | // dimension. The ith output tensor has 0th-dimension size 'sizes[i]'. |
57 | // |
58 | // REQUIRES: 'tensor' must have at least one dimension. |
59 | // REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'. |
60 | // REQUIRES: 'tensor' must point to data stored in CPU memory. |
61 | // REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not |
62 | // appropriately memory-aligned. |
63 | // |
64 | // Split() and Concat() are inverse operations. |
65 | Status Split(const Tensor& tensor, const gtl::ArraySlice<int64_t>& sizes, |
66 | std::vector<Tensor>* result) TF_MUST_USE_RESULT; |
67 | |
68 | namespace internal { |
69 | void SetTensorProtoShape(std::vector<size_t> shape, |
70 | TensorShapeProto* shape_proto); |
71 | |
// Primary template for the mapping from a C++ element type to the TensorProto
// repeated field that stores its values. Unspecialized types have no such
// field and derive from std::false_type, so `TensorProtoFieldHelper<T>::value`
// tells whether T is supported.
template <class Type>
class TensorProtoFieldHelper : public std::false_type {};
74 | |
75 | #define DEFINE_PROTO_FIELD_HELPER(TYPE, FIELDNAME) \ |
76 | template <> \ |
77 | class TensorProtoFieldHelper<TYPE> : public std::true_type { \ |
78 | public: \ |
79 | typedef decltype( \ |
80 | std::declval<TensorProto>().FIELDNAME##_val(0)) FieldType; \ |
81 | typedef decltype( \ |
82 | std::declval<TensorProto>().FIELDNAME##_val()) RepeatedFieldType; \ |
83 | typedef decltype(std::declval<TensorProto>().mutable_##FIELDNAME##_val()) \ |
84 | MutableRepeatedFieldType; \ |
85 | static MutableRepeatedFieldType GetMutableField(TensorProto* proto) { \ |
86 | return proto->mutable_##FIELDNAME##_val(); \ |
87 | } \ |
88 | static RepeatedFieldType& GetField(const TensorProto& proto) { \ |
89 | return proto.FIELDNAME##_val(); \ |
90 | } \ |
91 | } |
92 | |
93 | // The argument pairs in the following macro instantiations encode the |
94 | // mapping from C++ type ($1) to repeated field name "$2_val" used for storing |
95 | // values in TensorProto. See tensorflow/core/framework/tensor.proto. |
96 | DEFINE_PROTO_FIELD_HELPER(float, float); |
97 | DEFINE_PROTO_FIELD_HELPER(double, double); |
98 | DEFINE_PROTO_FIELD_HELPER(int8, int); |
99 | DEFINE_PROTO_FIELD_HELPER(uint8, int); |
100 | DEFINE_PROTO_FIELD_HELPER(int16, int); |
101 | DEFINE_PROTO_FIELD_HELPER(uint16, int); |
102 | DEFINE_PROTO_FIELD_HELPER(int32, int); |
103 | DEFINE_PROTO_FIELD_HELPER(uint32, uint32); |
104 | DEFINE_PROTO_FIELD_HELPER(int64_t, int64); |
105 | DEFINE_PROTO_FIELD_HELPER(uint64, uint64); |
106 | DEFINE_PROTO_FIELD_HELPER(bool, bool); |
107 | DEFINE_PROTO_FIELD_HELPER(qint8, int); |
108 | DEFINE_PROTO_FIELD_HELPER(quint8, int); |
109 | DEFINE_PROTO_FIELD_HELPER(qint16, int); |
110 | DEFINE_PROTO_FIELD_HELPER(quint16, int); |
111 | DEFINE_PROTO_FIELD_HELPER(qint32, int); |
112 | DEFINE_PROTO_FIELD_HELPER(Eigen::half, half); |
113 | DEFINE_PROTO_FIELD_HELPER(bfloat16, half); |
114 | DEFINE_PROTO_FIELD_HELPER(complex64, scomplex); |
115 | DEFINE_PROTO_FIELD_HELPER(complex128, dcomplex); |
116 | |
117 | #undef DEFINE_PROTO_HELPER |
118 | |
// Copies ranges of values between TensorProto repeated-field storage and
// in-memory arrays. The generic version converts element by element with
// static_cast; specializations handle types whose proto encoding differs from
// their in-memory representation.
template <typename T>
struct CopyHelper {
  // Copies [begin, end) to `dst`, static_cast-ing each element to the value
  // type of the destination iterator.
  template <typename SrcIter, typename DstIter>
  static void ToArray(SrcIter begin, SrcIter end, DstIter dst) {
    using SrcValue = typename std::iterator_traits<SrcIter>::value_type;
    using DstValue = typename std::iterator_traits<DstIter>::value_type;
    const auto convert = [](const SrcValue& value) -> DstValue {
      return static_cast<DstValue>(value);
    };
    std::transform(begin, end, dst, convert);
  }

  // Chosen by partial ordering when `dst` has the same type as the source
  // iterators: no conversion is needed, so the values are copied verbatim.
  template <typename SrcIter>
  static void ToArray(SrcIter begin, SrcIter end, SrcIter dst) {
    std::copy(begin, end, dst);
  }

  // Copies values toward proto storage. For the generic case this is the same
  // element-wise conversion that ToArray performs.
  template <typename SrcIter, typename DstIter>
  static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
    CopyHelper::ToArray(begin, end, dst);
  }
};
138 | |
139 | // Overloads for Eigen::half and bfloat16 that are 16 bits in size but are |
140 | // stored in an int32 field. |
141 | template <> |
142 | struct CopyHelper<Eigen::half> { |
143 | template <typename SrcIter> |
144 | static void ToArray(SrcIter begin, SrcIter end, Eigen::half* dst) { |
145 | std::transform(begin, end, dst, [](int x) -> Eigen::half { |
146 | return Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16>(x)); |
147 | }); |
148 | } |
149 | template <typename SrcIter, typename DstIter> |
150 | static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { |
151 | std::transform(begin, end, dst, [](Eigen::half h) -> int { |
152 | return static_cast<int>(Eigen::numext::bit_cast<uint16>(h)); |
153 | }); |
154 | } |
155 | }; |
156 | |
157 | template <> |
158 | struct CopyHelper<bfloat16> { |
159 | template <typename SrcIter> |
160 | static void ToArray(SrcIter begin, SrcIter end, bfloat16* dst) { |
161 | std::transform(begin, end, dst, [](int x) -> bfloat16 { |
162 | return Eigen::numext::bit_cast<bfloat16>(static_cast<uint16>(x)); |
163 | }); |
164 | } |
165 | template <typename SrcIter, typename DstIter> |
166 | static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { |
167 | std::transform(begin, end, dst, [](bfloat16 bf16) -> int { |
168 | return static_cast<int>(Eigen::numext::bit_cast<uint16>(bf16)); |
169 | }); |
170 | } |
171 | }; |
172 | |
173 | // Overloads for complex types that store real and imaginary parts |
174 | // at indices 2*i and 2*i+1 in float or double field. |
175 | template <typename RealType> |
176 | struct CopyHelper<std::complex<RealType>> { |
177 | template <typename SrcIter> |
178 | static void ToArray(SrcIter begin, SrcIter end, std::complex<RealType>* dst) { |
179 | RealType* real_dst = reinterpret_cast<RealType*>(dst); |
180 | std::copy(begin, end, real_dst); |
181 | } |
182 | |
183 | template <typename SrcIter, typename DstIter> |
184 | static void FromArray(SrcIter begin, SrcIter end, DstIter dst) { |
185 | size_t n = std::distance(begin, end); |
186 | const RealType* real_begin = reinterpret_cast<const RealType*>(&(*begin)); |
187 | std::copy_n(real_begin, 2 * n, dst); |
188 | } |
189 | }; |
190 | |
191 | // Helper class to extract and insert values into TensorProto represented as |
192 | // repeated fields. |
193 | template <typename T> |
194 | class TensorProtoHelper : public std::true_type { |
195 | public: |
196 | using FieldHelper = TensorProtoFieldHelper<T>; |
197 | using FieldType = typename TensorProtoFieldHelper<T>::FieldType; |
198 | |
199 | static DataType GetDataType() { return DataTypeToEnum<T>::value; } |
200 | |
201 | // Returns the number of values of type T encoded in the proto. |
202 | static size_t NumValues(const TensorProto& proto) { |
203 | size_t raw_size = FieldHelper::GetField(proto).size(); |
204 | return is_complex<T>::value ? raw_size / 2 : raw_size; |
205 | } |
206 | |
207 | static void AddValue(const T& value, TensorProto* proto) { |
208 | const T* val_ptr = &value; |
209 | AddValues(val_ptr, val_ptr + 1, proto); |
210 | } |
211 | |
212 | static T GetValue(size_t index, const TensorProto& proto) { |
213 | const size_t stride = is_complex<T>::value ? 2 : 1; |
214 | T val; |
215 | CopyHelper<T>::ToArray( |
216 | FieldHelper::GetField(proto).begin() + stride * index, |
217 | FieldHelper::GetField(proto).begin() + stride * (index + 1), &val); |
218 | return val; |
219 | } |
220 | |
221 | template <typename IterType> |
222 | static void AddValues(IterType begin, IterType end, TensorProto* proto) { |
223 | size_t n = std::distance(begin, end); |
224 | FieldType* dst = AppendUninitialized(n, proto); |
225 | CopyHelper<T>::FromArray(begin, end, dst); |
226 | } |
227 | |
228 | template <typename IterType> |
229 | static void CopyValues(IterType dst, const TensorProto& proto) { |
230 | CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin(), |
231 | FieldHelper::GetField(proto).end(), dst); |
232 | } |
233 | |
234 | static void Truncate(size_t new_size, TensorProto* proto) { |
235 | if (is_complex<T>::value) new_size *= 2; |
236 | FieldHelper::GetMutableField(proto)->Truncate(new_size); |
237 | } |
238 | |
239 | static FieldType* AppendUninitialized(size_t n, TensorProto* proto) { |
240 | if (is_complex<T>::value) n *= 2; |
241 | auto* field = FieldHelper::GetMutableField(proto); |
242 | field->Reserve(field->size() + n); |
243 | return reinterpret_cast<FieldType*>(field->AddNAlreadyReserved(n)); |
244 | } |
245 | }; |
246 | |
247 | // Specialization for string. |
248 | template <> |
249 | class TensorProtoHelper<string> : public std::true_type { |
250 | public: |
251 | static DataType GetDataType() { return DataType::DT_STRING; } |
252 | static void AddValue(const string& value, TensorProto* proto) { |
253 | *proto->mutable_string_val()->Add() = value; |
254 | } |
255 | template <typename IterType> |
256 | static void AddValues(IterType begin, IterType end, TensorProto* proto) { |
257 | for (IterType it = begin; it != end; ++it) { |
258 | AddValue(*it, proto); |
259 | } |
260 | } |
261 | template <typename IterType> |
262 | static void CopyToTensorContent(IterType begin, IterType end, |
263 | TensorProto* proto) { |
264 | AddValues(begin, end, proto); |
265 | } |
266 | }; |
267 | |
268 | } // namespace internal |
269 | |
270 | // Creates a 'TensorProto' with specified shape and values. |
271 | // The dtype and a field to represent data values of the returned 'TensorProto' |
272 | // are determined based on type of the 'values' parameter. |
273 | template <typename Type> |
274 | typename std::enable_if<internal::TensorProtoHelper<Type>::value, |
275 | TensorProto>::type |
276 | CreateTensorProto(const std::vector<Type>& values, |
277 | const std::vector<size_t>& shape) { |
278 | TensorProto tensor; |
279 | TensorShapeProto tensor_shape_proto; |
280 | internal::SetTensorProtoShape(shape, &tensor_shape_proto); |
281 | if (TensorShape(tensor_shape_proto).num_elements() != values.size()) { |
282 | LOG(ERROR) << "Shape and number of values (" << values.size() |
283 | << ") are incompatible." ; |
284 | return tensor; |
285 | } |
286 | using TypeHelper = internal::TensorProtoHelper<Type>; |
287 | tensor.set_dtype(TypeHelper::GetDataType()); |
288 | tensor.mutable_tensor_shape()->Swap(&tensor_shape_proto); |
289 | TypeHelper::AddValues(values.begin(), values.end(), &tensor); |
290 | return tensor; |
291 | } |
292 | |
293 | // Converts values in tensor to run-length encoded compressed form. |
294 | // |
295 | // The elements of a tensor can be stored in a TensorProto in one of the |
296 | // following two forms: |
297 | // 1. As a raw byte string in the field `tensor_content` containing the |
298 | // serialized in-memory representation of the tensor. |
299 | // 2. As values of a repeated field depending on the datatype, e.g. that |
300 | // values of a DT_FLOAT tensor would be stored in the repeated field |
301 | // `float_val`. |
302 | // Storage scheme 2 may use a simple form of run-length encoding to compress |
303 | // data: If the values contains a tail of identical values, the repeated field |
304 | // will be truncated such that the number of values in the repeated field is |
305 | // less than the number of elements implied by the field`tensor_shape`. The |
306 | // original tensor can be recovered by repeating the final value in the repeated |
307 | // field. |
308 | // |
309 | // The TensorProto will be compressed if a) the tensor contains at least |
310 | // min_num_elements elements and b) the compressed tensor proto is would be at |
311 | // most the size of the original tensor proto divided by min_compression_ratio. |
312 | // |
313 | // Returns true if the tensor was compressed. |
314 | bool CompressTensorProtoInPlace(int64_t min_num_elements, |
315 | float min_compression_ratio, |
316 | TensorProto* tensor); |
317 | |
318 | inline bool CompressTensorProtoInPlace(TensorProto* tensor) { |
319 | static const int64_t kDefaultMinNumElements = 64; |
320 | static const float kDefaultMinCompressionRatio = 2.0f; |
321 | return CompressTensorProtoInPlace(kDefaultMinNumElements, |
322 | kDefaultMinCompressionRatio, tensor); |
323 | } |
324 | |
325 | // Make a TensorShape from the contents of shape_t. Shape_t must be a |
326 | // 1-dimensional tensor of type int32 or int64. |
327 | Status MakeShape(const Tensor& shape_t, TensorShape* out); |
328 | |
329 | } // namespace tensor |
330 | } // namespace tensorflow |
331 | |
332 | #endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ |
333 | |