/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_util.h"

#include <cmath>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensor {
Tensor DeepCopy(const Tensor& other) {
  Tensor tmp = Tensor(other.dtype(), other.shape());
  DeepCopy(other, &tmp);
  return tmp;
}

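// Copies the contents of 'input' into '*output'. POD dtypes are copied with a
// single memcpy; DT_STRING and DT_VARIANT elements are copied one by one. The
// code assumes 'output' has already been allocated with the same dtype and
// shape as 'input'.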
void DeepCopy(const Tensor& input, Tensor* output) {
  if (DataTypeCanUseMemcpy(input.dtype())) {
    if (input.NumElements() > 0) {
      StringPiece input_data = input.tensor_data();

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece output_data = output->tensor_data();
      memcpy(const_cast<char*>(output_data.data()), input_data.data(),
             input_data.size());
    }
  } else if (input.dtype() == DT_STRING) {
    output->unaligned_flat<tstring>() = input.unaligned_flat<tstring>();
  } else {
    CHECK_EQ(DT_VARIANT, input.dtype());
    output->unaligned_flat<Variant>() = input.unaligned_flat<Variant>();
  }
}

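// Concatenates 'tensors' along their first dimension into '*result'. All
// inputs must have rank >= 1 and the same dtype; the result takes the shape
// of tensors[0] with dim 0 set to the sum of the inputs' dim 0 sizes.
//
// Illustrative usage (a sketch; assumes 'a' and 'b' are existing float
// tensors of shape [2, 3]):
//
//   Tensor concatenated;
//   TF_RETURN_IF_ERROR(tensor::Concat({a, b}, &concatenated));
//   // 'concatenated' now has shape [4, 3].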
Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
  if (tensors.empty()) {
    return errors::InvalidArgument("Cannot concatenate zero tensors");
  }
  int64_t total_dim0_size = 0;
  for (const Tensor& tensor : tensors) {
    if (tensor.dims() == 0) {
      return errors::InvalidArgument(
          "Cannot concatenate a zero-dimensional tensor");
    }
    total_dim0_size += tensor.dim_size(0);
  }
  TensorShape shape = tensors[0].shape();
  shape.set_dim(0, total_dim0_size);

  const DataType dtype = tensors[0].dtype();
  for (int i = 1; i < tensors.size(); ++i) {
    if (tensors[i].dtype() != dtype) {
      return errors::InvalidArgument(
          "Cannot concatenate tensors that have different data types. Got ",
          DataTypeString(dtype), " and ", DataTypeString(tensors[i].dtype()),
          ".");
    }
  }
  *result = Tensor(dtype, shape);

  // We use StringPiece as a convenient map over the tensor buffer,
  // but we cast the type to get to the underlying buffer to do the
  // copy.
  StringPiece to_data = result->tensor_data();

  if (DataTypeCanUseMemcpy(dtype)) {
    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      StringPiece from_data = tensor.tensor_data();
      CHECK_LE(offset + from_data.size(), to_data.size());
      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
             from_data.size());

      offset += from_data.size();
    }
  } else {
    if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    tstring* to_strings =
        reinterpret_cast<tstring*>(const_cast<char*>(to_data.data()));

    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      auto from_strings = tensor.flat<tstring>();
      CHECK_LE(offset + tensor.NumElements(), result->NumElements());
      for (int i = 0; i < tensor.NumElements(); ++i) {
        to_strings[offset + i] = from_strings(i);
      }

      offset += tensor.NumElements();
    }
  }

  return OkStatus();
}

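// Splits 'tensor' along its first dimension, appending to '*result' one piece
// per entry in 'sizes'; the entries must sum to tensor.dim_size(0).
//
// Illustrative usage (a sketch; assumes 't' is an existing float tensor of
// shape [4, 3]):
//
//   std::vector<Tensor> pieces;
//   TF_RETURN_IF_ERROR(tensor::Split(t, {1, 3}, &pieces));
//   // pieces[0] has shape [1, 3] and pieces[1] has shape [3, 3].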
Status Split(const Tensor& tensor, const gtl::ArraySlice<int64_t>& sizes,
             std::vector<Tensor>* result) {
  if (tensor.dims() == 0) {
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
  }
  int64_t total_size = 0;
  for (int64_t size : sizes) {
    total_size += size;
  }
  if (total_size != tensor.dim_size(0)) {
    return errors::InvalidArgument(
        "The values in 'sizes' do not sum to the zeroth-dimension size of "
        "'tensor'");
  }

  StringPiece from_data = tensor.tensor_data();

  if (DataTypeCanUseMemcpy(tensor.dtype())) {
    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor* split = &(*result)[result->size() - 1];

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece to_data = split->tensor_data();
      CHECK_LE(offset + to_data.size(), from_data.size());
      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
             to_data.size());

      offset += to_data.size();
    }
  } else {
    if (tensor.dtype() != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    auto from_strings = tensor.flat<tstring>();

    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor& split = (*result)[result->size() - 1];
      tstring* to_strings = reinterpret_cast<tstring*>(
          const_cast<char*>(split.tensor_data().data()));

      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
      for (int i = 0; i < split.NumElements(); ++i) {
        to_strings[i] = from_strings(offset + i);
      }

      offset += split.NumElements();
    }
  }

  return OkStatus();
}

namespace internal {
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto) {
  for (auto dim : shape) {
    shape_proto->mutable_dim()->Add()->set_size(dim);
  }
}

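// Attempts to re-encode a TensorProto whose data is stored as raw bytes in
// 'tensor_content'. A trailing run of identical elements is dropped from the
// encoding, and an all-zero splat is erased entirely. Otherwise the leading
// values are moved into the typed repeated field and 'tensor_content' is
// cleared, but only if the new encoding is at most 1/min_compression_ratio of
// the original size. Returns true if 'tensor' was modified.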
template <typename T>
bool CompressTensorContent(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_bytes = tensor->tensor_content().size();
  const int64_t num_raw_values = num_bytes / sizeof(T);
  if (num_raw_values != num_tensor_values) {
    // Invalid or too small.
    return false;
  }
  int64_t last_offset = num_bytes - 1;
  int64_t prev_offset = last_offset - sizeof(T);
  // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements,
  // starting from the end, to find the last pair of elements that are not
  // identical.
  while (prev_offset >= 0) {
    if (tensor->tensor_content()[prev_offset] !=
        tensor->tensor_content()[last_offset]) {
      break;
    }
    --last_offset;
    --prev_offset;
  }
  if (prev_offset == -1) {
    // If this is a splat of value 0, it does not need an explicit value; just
    // erase the content.
    T splat_value;
    port::CopySubrangeToArray(tensor->tensor_content(), 0, sizeof(T),
                              reinterpret_cast<char*>(&splat_value));
    if (splat_value == T(0)) {
      tensor->clear_tensor_content();
      return true;
    }
  }
  // Round up to the next whole number of elements of type T.
  const int64_t new_num_values = last_offset / sizeof(T) + 1;
  if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
      static_cast<int64_t>(num_bytes / min_compression_ratio)) {
    return false;
  }
  // Copy values to truncated repeated field.
  if constexpr (sizeof(FieldType) == sizeof(T)) {
    FieldType* dst_ptr =
        TypeHelper::AppendUninitialized(new_num_values, tensor);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(dst_ptr));
    tensor->clear_tensor_content();
  } else if constexpr (sizeof(T) > 1) {
    // Copy raw bytes to temp array first, then cast.
    gtl::InlinedVector<T, 64> tmp;
    if (new_num_values >= tmp.max_size()) return false;
    tmp.resize(new_num_values);

    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(tmp.data()));
    tensor->clear_tensor_content();
    TypeHelper::AddValues(tmp.begin(), tmp.end(), tensor);
  } else {
    // Copy and cast, one byte at a time.
    for (int64_t i = 0; i < new_num_values; ++i) {
      char c = tensor->tensor_content()[i];
      TypeHelper::AddValue(static_cast<T>(c), tensor);
    }
    tensor->clear_tensor_content();
  }
  return true;
}

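// Compares values by bit pattern rather than by arithmetic equality, so that
// a run of bit-identical NaNs counts as a run and -0.0 is distinguished from
// +0.0; plain operator!= would do neither.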
template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}

// Integral types cannot represent negative zero.
template <typename T,
          typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
static bool IsNegativeZero(T value) {
  return false;
}

template <typename T, typename std::enable_if<
                          !std::is_integral<T>::value>::type* = nullptr>
static bool IsNegativeZero(T value) {
  return value == T(0) && std::signbit(value);
}

template <typename T>
static bool IsNegativeZero(std::complex<T> value) {
  return IsNegativeZero(value.real()) || IsNegativeZero(value.imag());
}

static bool IsNegativeZero(Eigen::QUInt8 value) { return false; }
static bool IsNegativeZero(Eigen::QInt8 value) { return false; }
static bool IsNegativeZero(Eigen::QUInt16 value) { return false; }
static bool IsNegativeZero(Eigen::QInt16 value) { return false; }
static bool IsNegativeZero(Eigen::QInt32 value) { return false; }
static bool IsNegativeZero(Eigen::half value) {
  return IsNegativeZero<float>(value);
}
static bool IsNegativeZero(Eigen::bfloat16 value) {
  return IsNegativeZero<float>(value);
}

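// Attempts to shrink a TensorProto whose data is stored in the typed repeated
// field. A trailing run of bit-identical values is dropped; an all-zero
// tensor (with no -0.0) is truncated to an empty field. Otherwise the smaller
// of the truncated repeated field and a raw 'tensor_content' encoding is
// kept, provided it is at most 1/min_compression_ratio of the original size.
// Returns true if 'tensor' was modified.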
template <typename T>
bool CompressRepeatedField(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_proto_values = TypeHelper::NumValues(*tensor);

  // Notice that for complex types the tensor is stored as an array of up to
  // 2 * num_tensor_values real values (real and imaginary parts), possibly
  // truncated. A 0-splat does not need any value present and is maximally
  // compressed.
  if (num_proto_values == 0) return false;

  const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
  int64_t last_index = 0;
  for (int64_t i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
    const T cur_value = TypeHelper::GetValue(i, *tensor);
    if (PackedValuesNotEqual(cur_value, last_value)) {
      last_index = i + 1;
    }
  }
  // Detect all-zero tensors: zero is the default value, so the content can be
  // erased entirely.
  if (last_index == 0 && last_value == T(0) && !IsNegativeZero(last_value)) {
    TypeHelper::Truncate(0, tensor);
    return true;
  }

  const int64_t num_truncated_proto_values = last_index + 1;
  const int64_t num_bytes_as_field =
      num_truncated_proto_values * sizeof(FieldType);
  const int64_t num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
  const int64_t num_bytes_before = num_proto_values * sizeof(FieldType);
  if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
      static_cast<int64_t>(num_bytes_before / min_compression_ratio)) {
    return false;
  }
  if (num_bytes_as_field <= num_bytes_as_tensor_content) {
    TypeHelper::Truncate(num_truncated_proto_values, tensor);
  } else {
    gtl::InlinedVector<T, 64> tmp;
    if (num_proto_values == 1) {
      // Splat case.
      tmp.resize(num_tensor_values, last_value);
    } else {
      tmp.resize(num_tensor_values, T(0));
      TypeHelper::CopyValues(tmp.begin(), *tensor);
    }
    TypeHelper::Truncate(0, tensor);
    port::CopyFromArray(tensor->mutable_tensor_content(),
                        reinterpret_cast<const char*>(tmp.data()),
                        num_bytes_as_tensor_content);
  }
  return true;
}

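// Dispatches to the routine matching the proto's current encoding: raw bytes
// in 'tensor_content' or the typed repeated field. Tensors with fewer than
// 'min_num_elements' values are left untouched.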
template <typename T>
bool CompressTensorProtoInPlaceImpl(int64_t min_num_elements,
                                    float min_compression_ratio,
                                    TensorProto* tensor) {
  const TensorShape shape(tensor->tensor_shape());
  const int64_t num_tensor_values = shape.num_elements();
  if (num_tensor_values < min_num_elements) {
    return false;
  }
  if (tensor->tensor_content().empty()) {
    return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
  } else {
    return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
  }
}

}  // namespace internal

#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor); \
    break

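// Compresses 'tensor' in place when it holds at least 'min_num_elements'
// values and the smaller encoding wins by at least 'min_compression_ratio'.
// Returns true if the proto was rewritten.
//
// Illustrative usage (a sketch, not taken from existing call sites):
//
//   Tensor t(DT_FLOAT, TensorShape({1024}));
//   t.flat<float>().setZero();
//   TensorProto proto;
//   t.AsProtoTensorContent(&proto);
//   // An all-zero splat compresses down to an empty representation.
//   CompressTensorProtoInPlace(/*min_num_elements=*/64,
//                              /*min_compression_ratio=*/2.0f, &proto);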
bool CompressTensorProtoInPlace(int64_t min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor) {
  switch (tensor->dtype()) {
    HANDLE_COMPRESS_CASE(DT_FLOAT);
    HANDLE_COMPRESS_CASE(DT_DOUBLE);
    HANDLE_COMPRESS_CASE(DT_COMPLEX64);
    HANDLE_COMPRESS_CASE(DT_COMPLEX128);
    HANDLE_COMPRESS_CASE(DT_UINT8);
    HANDLE_COMPRESS_CASE(DT_INT8);
    HANDLE_COMPRESS_CASE(DT_UINT16);
    HANDLE_COMPRESS_CASE(DT_INT16);
    HANDLE_COMPRESS_CASE(DT_UINT32);
    HANDLE_COMPRESS_CASE(DT_INT32);
    HANDLE_COMPRESS_CASE(DT_UINT64);
    HANDLE_COMPRESS_CASE(DT_INT64);
    HANDLE_COMPRESS_CASE(DT_BOOL);
    HANDLE_COMPRESS_CASE(DT_QUINT8);
    HANDLE_COMPRESS_CASE(DT_QINT8);
    HANDLE_COMPRESS_CASE(DT_QUINT16);
    HANDLE_COMPRESS_CASE(DT_QINT16);
    HANDLE_COMPRESS_CASE(DT_QINT32);
    HANDLE_COMPRESS_CASE(DT_HALF);
    HANDLE_COMPRESS_CASE(DT_BFLOAT16);
    default:
      return false;
  }
}

#undef HANDLE_COMPRESS_CASE

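// Converts a rank-1 DT_INT32 or DT_INT64 'shape' tensor into a TensorShape.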
Status MakeShape(const Tensor& shape, TensorShape* out) {
  if (!TensorShapeUtils::IsVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64_t>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}

}  // namespace tensor
}  // namespace tensorflow