1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
17#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
18
19#include <algorithm>
20#include <vector>
21
22#include "tensorflow/core/framework/tensor.h"
23#include "tensorflow/core/framework/tensor.pb.h"
24#include "tensorflow/core/framework/tensor_shape.pb.h"
25#include "tensorflow/core/framework/type_traits.h"
26#include "tensorflow/core/platform/protobuf.h"
27#include "tensorflow/core/platform/types.h"
28
29namespace tensorflow {
30namespace tensor {
31
// DeepCopy returns a tensor whose contents are a deep copy of the
// contents of 'other'. This function is intended only for
// convenience, not speed.
//
// REQUIRES: 'other' must point to data stored in CPU memory.
// REQUIRES: 'other' must be a Tensor of a copy-able type if
// 'other' is not appropriately memory-aligned.
Tensor DeepCopy(const Tensor& other);

// Deep copies 'input' into '*output'. This is like the overload above, except
// that the memory for 'output' must already have been allocated by the caller.
// NOTE(review): presumably 'output' must already match the dtype and shape of
// 'input' -- confirm against the definition.
void DeepCopy(const Tensor& input, Tensor* output);
44
// Concatenates 'tensors' into a single tensor, along their 0th dimension,
// and stores it in '*result'. The returned Status must be checked
// (TF_MUST_USE_RESULT).
//
// REQUIRES: All members of 'tensors' must have the same data type parameter.
// REQUIRES: Each member of 'tensors' must have at least one dimension.
// REQUIRES: Each member of 'tensors' must point to data stored in CPU memory.
// REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it
//           is not appropriately memory-aligned.
Status Concat(const gtl::ArraySlice<Tensor>& tensors,
              Tensor* result) TF_MUST_USE_RESULT;
54
// Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th
// dimension, and stores them in '*result'. The ith output tensor has
// 0th-dimension size 'sizes[i]'. The returned Status must be checked
// (TF_MUST_USE_RESULT).
//
// REQUIRES: 'tensor' must have at least one dimension.
// REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'.
// REQUIRES: 'tensor' must point to data stored in CPU memory.
// REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not
//           appropriately memory-aligned.
//
// Split() and Concat() are inverse operations.
Status Split(const Tensor& tensor, const gtl::ArraySlice<int64_t>& sizes,
             std::vector<Tensor>* result) TF_MUST_USE_RESULT;
67
68namespace internal {
// Populates 'shape_proto' from the dimension sizes listed in 'shape'.
// NOTE(review): the definition is not in this header; this description is
// inferred from the name and from how CreateTensorProto uses the result
// (it builds a TensorShape from 'shape_proto') -- confirm.
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto);
71
72template <typename Type>
73class TensorProtoFieldHelper : public std::false_type {};
74
75#define DEFINE_PROTO_FIELD_HELPER(TYPE, FIELDNAME) \
76 template <> \
77 class TensorProtoFieldHelper<TYPE> : public std::true_type { \
78 public: \
79 typedef decltype( \
80 std::declval<TensorProto>().FIELDNAME##_val(0)) FieldType; \
81 typedef decltype( \
82 std::declval<TensorProto>().FIELDNAME##_val()) RepeatedFieldType; \
83 typedef decltype(std::declval<TensorProto>().mutable_##FIELDNAME##_val()) \
84 MutableRepeatedFieldType; \
85 static MutableRepeatedFieldType GetMutableField(TensorProto* proto) { \
86 return proto->mutable_##FIELDNAME##_val(); \
87 } \
88 static RepeatedFieldType& GetField(const TensorProto& proto) { \
89 return proto.FIELDNAME##_val(); \
90 } \
91 }
92
93// The argument pairs in the following macro instantiations encode the
94// mapping from C++ type ($1) to repeated field name "$2_val" used for storing
95// values in TensorProto. See tensorflow/core/framework/tensor.proto.
96DEFINE_PROTO_FIELD_HELPER(float, float);
97DEFINE_PROTO_FIELD_HELPER(double, double);
98DEFINE_PROTO_FIELD_HELPER(int8, int);
99DEFINE_PROTO_FIELD_HELPER(uint8, int);
100DEFINE_PROTO_FIELD_HELPER(int16, int);
101DEFINE_PROTO_FIELD_HELPER(uint16, int);
102DEFINE_PROTO_FIELD_HELPER(int32, int);
103DEFINE_PROTO_FIELD_HELPER(uint32, uint32);
104DEFINE_PROTO_FIELD_HELPER(int64_t, int64);
105DEFINE_PROTO_FIELD_HELPER(uint64, uint64);
106DEFINE_PROTO_FIELD_HELPER(bool, bool);
107DEFINE_PROTO_FIELD_HELPER(qint8, int);
108DEFINE_PROTO_FIELD_HELPER(quint8, int);
109DEFINE_PROTO_FIELD_HELPER(qint16, int);
110DEFINE_PROTO_FIELD_HELPER(quint16, int);
111DEFINE_PROTO_FIELD_HELPER(qint32, int);
112DEFINE_PROTO_FIELD_HELPER(Eigen::half, half);
113DEFINE_PROTO_FIELD_HELPER(bfloat16, half);
114DEFINE_PROTO_FIELD_HELPER(complex64, scomplex);
115DEFINE_PROTO_FIELD_HELPER(complex128, dcomplex);
116
117#undef DEFINE_PROTO_HELPER
118
// Generic element copier used to move values between a TensorProto repeated
// field and a flat array. When the source and destination iterators differ in
// type, every element is converted with static_cast; when they share a type,
// a plain copy is performed.
template <typename T>
struct CopyHelper {
  // Copies [first, last) into `out`, static_cast-ing each element to the
  // value type of OutIt.
  template <typename InIt, typename OutIt>
  static void ToArray(InIt first, InIt last, OutIt out) {
    using In = typename std::iterator_traits<InIt>::value_type;
    using Out = typename std::iterator_traits<OutIt>::value_type;
    const auto convert = [](const In& element) -> Out {
      return static_cast<Out>(element);
    };
    std::transform(first, last, out, convert);
  }

  // Same-iterator-type case (chosen by partial ordering): element-wise copy.
  template <typename InIt>
  static void ToArray(InIt first, InIt last, InIt out) {
    std::copy(first, last, out);
  }

  // Writing into the proto is the same conversion in the other direction.
  template <typename InIt, typename OutIt>
  static void FromArray(InIt first, InIt last, OutIt out) {
    ToArray(first, last, out);
  }
};
138
139// Overloads for Eigen::half and bfloat16 that are 16 bits in size but are
140// stored in an int32 field.
141template <>
142struct CopyHelper<Eigen::half> {
143 template <typename SrcIter>
144 static void ToArray(SrcIter begin, SrcIter end, Eigen::half* dst) {
145 std::transform(begin, end, dst, [](int x) -> Eigen::half {
146 return Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16>(x));
147 });
148 }
149 template <typename SrcIter, typename DstIter>
150 static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
151 std::transform(begin, end, dst, [](Eigen::half h) -> int {
152 return static_cast<int>(Eigen::numext::bit_cast<uint16>(h));
153 });
154 }
155};
156
157template <>
158struct CopyHelper<bfloat16> {
159 template <typename SrcIter>
160 static void ToArray(SrcIter begin, SrcIter end, bfloat16* dst) {
161 std::transform(begin, end, dst, [](int x) -> bfloat16 {
162 return Eigen::numext::bit_cast<bfloat16>(static_cast<uint16>(x));
163 });
164 }
165 template <typename SrcIter, typename DstIter>
166 static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
167 std::transform(begin, end, dst, [](bfloat16 bf16) -> int {
168 return static_cast<int>(Eigen::numext::bit_cast<uint16>(bf16));
169 });
170 }
171};
172
173// Overloads for complex types that store real and imaginary parts
174// at indices 2*i and 2*i+1 in float or double field.
175template <typename RealType>
176struct CopyHelper<std::complex<RealType>> {
177 template <typename SrcIter>
178 static void ToArray(SrcIter begin, SrcIter end, std::complex<RealType>* dst) {
179 RealType* real_dst = reinterpret_cast<RealType*>(dst);
180 std::copy(begin, end, real_dst);
181 }
182
183 template <typename SrcIter, typename DstIter>
184 static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
185 size_t n = std::distance(begin, end);
186 const RealType* real_begin = reinterpret_cast<const RealType*>(&(*begin));
187 std::copy_n(real_begin, 2 * n, dst);
188 }
189};
190
191// Helper class to extract and insert values into TensorProto represented as
192// repeated fields.
193template <typename T>
194class TensorProtoHelper : public std::true_type {
195 public:
196 using FieldHelper = TensorProtoFieldHelper<T>;
197 using FieldType = typename TensorProtoFieldHelper<T>::FieldType;
198
199 static DataType GetDataType() { return DataTypeToEnum<T>::value; }
200
201 // Returns the number of values of type T encoded in the proto.
202 static size_t NumValues(const TensorProto& proto) {
203 size_t raw_size = FieldHelper::GetField(proto).size();
204 return is_complex<T>::value ? raw_size / 2 : raw_size;
205 }
206
207 static void AddValue(const T& value, TensorProto* proto) {
208 const T* val_ptr = &value;
209 AddValues(val_ptr, val_ptr + 1, proto);
210 }
211
212 static T GetValue(size_t index, const TensorProto& proto) {
213 const size_t stride = is_complex<T>::value ? 2 : 1;
214 T val;
215 CopyHelper<T>::ToArray(
216 FieldHelper::GetField(proto).begin() + stride * index,
217 FieldHelper::GetField(proto).begin() + stride * (index + 1), &val);
218 return val;
219 }
220
221 template <typename IterType>
222 static void AddValues(IterType begin, IterType end, TensorProto* proto) {
223 size_t n = std::distance(begin, end);
224 FieldType* dst = AppendUninitialized(n, proto);
225 CopyHelper<T>::FromArray(begin, end, dst);
226 }
227
228 template <typename IterType>
229 static void CopyValues(IterType dst, const TensorProto& proto) {
230 CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin(),
231 FieldHelper::GetField(proto).end(), dst);
232 }
233
234 static void Truncate(size_t new_size, TensorProto* proto) {
235 if (is_complex<T>::value) new_size *= 2;
236 FieldHelper::GetMutableField(proto)->Truncate(new_size);
237 }
238
239 static FieldType* AppendUninitialized(size_t n, TensorProto* proto) {
240 if (is_complex<T>::value) n *= 2;
241 auto* field = FieldHelper::GetMutableField(proto);
242 field->Reserve(field->size() + n);
243 return reinterpret_cast<FieldType*>(field->AddNAlreadyReserved(n));
244 }
245};
246
247// Specialization for string.
248template <>
249class TensorProtoHelper<string> : public std::true_type {
250 public:
251 static DataType GetDataType() { return DataType::DT_STRING; }
252 static void AddValue(const string& value, TensorProto* proto) {
253 *proto->mutable_string_val()->Add() = value;
254 }
255 template <typename IterType>
256 static void AddValues(IterType begin, IterType end, TensorProto* proto) {
257 for (IterType it = begin; it != end; ++it) {
258 AddValue(*it, proto);
259 }
260 }
261 template <typename IterType>
262 static void CopyToTensorContent(IterType begin, IterType end,
263 TensorProto* proto) {
264 AddValues(begin, end, proto);
265 }
266};
267
268} // namespace internal
269
270// Creates a 'TensorProto' with specified shape and values.
271// The dtype and a field to represent data values of the returned 'TensorProto'
272// are determined based on type of the 'values' parameter.
273template <typename Type>
274typename std::enable_if<internal::TensorProtoHelper<Type>::value,
275 TensorProto>::type
276CreateTensorProto(const std::vector<Type>& values,
277 const std::vector<size_t>& shape) {
278 TensorProto tensor;
279 TensorShapeProto tensor_shape_proto;
280 internal::SetTensorProtoShape(shape, &tensor_shape_proto);
281 if (TensorShape(tensor_shape_proto).num_elements() != values.size()) {
282 LOG(ERROR) << "Shape and number of values (" << values.size()
283 << ") are incompatible.";
284 return tensor;
285 }
286 using TypeHelper = internal::TensorProtoHelper<Type>;
287 tensor.set_dtype(TypeHelper::GetDataType());
288 tensor.mutable_tensor_shape()->Swap(&tensor_shape_proto);
289 TypeHelper::AddValues(values.begin(), values.end(), &tensor);
290 return tensor;
291}
292
// Converts values in tensor to run-length encoded compressed form.
//
// The elements of a tensor can be stored in a TensorProto in one of the
// following two forms:
// 1. As a raw byte string in the field `tensor_content` containing the
//    serialized in-memory representation of the tensor.
// 2. As values of a repeated field depending on the datatype, e.g. the
//    values of a DT_FLOAT tensor would be stored in the repeated field
//    `float_val`.
// Storage scheme 2 may use a simple form of run-length encoding to compress
// data: if the values contain a tail of identical values, the repeated field
// will be truncated such that the number of values in the repeated field is
// less than the number of elements implied by the field `tensor_shape`. The
// original tensor can be recovered by repeating the final value in the
// repeated field.
//
// The TensorProto will be compressed if a) the tensor contains at least
// min_num_elements elements and b) the compressed tensor proto would be at
// most the size of the original tensor proto divided by
// min_compression_ratio.
//
// Returns true if the tensor was compressed.
bool CompressTensorProtoInPlace(int64_t min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor);
317
318inline bool CompressTensorProtoInPlace(TensorProto* tensor) {
319 static const int64_t kDefaultMinNumElements = 64;
320 static const float kDefaultMinCompressionRatio = 2.0f;
321 return CompressTensorProtoInPlace(kDefaultMinNumElements,
322 kDefaultMinCompressionRatio, tensor);
323}
324
// Makes a TensorShape from the contents of 'shape_t', which must be a
// 1-dimensional tensor of type int32 or int64, and stores it in '*out'.
// NOTE(review): presumably returns a non-OK Status when 'shape_t' does not
// meet these requirements -- confirm against the definition.
Status MakeShape(const Tensor& shape_t, TensorShape* out);
328
329} // namespace tensor
330} // namespace tensorflow
331
332#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
333