1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/tensor_util.h" |
17 | |
18 | #include <cmath> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/framework/tensor_shape.h" |
23 | #include "tensorflow/core/framework/type_traits.h" |
24 | #include "tensorflow/core/framework/variant.h" |
25 | #include "tensorflow/core/lib/core/stringpiece.h" |
26 | #include "tensorflow/core/platform/protobuf.h" |
27 | #include "tensorflow/core/platform/tensor_coding.h" |
28 | #include "tensorflow/core/platform/types.h" |
29 | |
30 | namespace tensorflow { |
31 | namespace tensor { |
32 | |
33 | Tensor DeepCopy(const Tensor& other) { |
34 | Tensor tmp = Tensor(other.dtype(), other.shape()); |
35 | DeepCopy(other, &tmp); |
36 | return tmp; |
37 | } |
38 | |
39 | void DeepCopy(const Tensor& input, Tensor* output) { |
40 | if (DataTypeCanUseMemcpy(input.dtype())) { |
41 | if (input.NumElements() > 0) { |
42 | StringPiece input_data = input.tensor_data(); |
43 | |
44 | // We use StringPiece as a convenient map over the tensor buffer, |
45 | // but we cast the type to get to the underlying buffer to do the |
46 | // copy. |
47 | StringPiece output_data = output->tensor_data(); |
48 | memcpy(const_cast<char*>(output_data.data()), input_data.data(), |
49 | input_data.size()); |
50 | } |
51 | } else if (input.dtype() == DT_STRING) { |
52 | output->unaligned_flat<tstring>() = input.unaligned_flat<tstring>(); |
53 | } else { |
54 | CHECK_EQ(DT_VARIANT, input.dtype()); |
55 | output->unaligned_flat<Variant>() = input.unaligned_flat<Variant>(); |
56 | } |
57 | } |
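
// Example use of DeepCopy (an illustrative sketch, not exercised in this
// file): the copy owns its own buffer, so later writes to the source tensor
// do not alias into it.
//
//   Tensor a(DT_FLOAT, TensorShape({2}));
//   a.flat<float>().setConstant(1.0f);
//   Tensor b = tensor::DeepCopy(a);
//   a.flat<float>()(0) = 2.0f;  // b.flat<float>()(0) is still 1.0f.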
58 | |
59 | Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) { |
60 | if (tensors.empty()) { |
    return errors::InvalidArgument("Cannot concatenate zero tensors");
62 | } |
63 | int64_t total_dim0_size = 0; |
64 | for (const Tensor& tensor : tensors) { |
65 | if (tensor.dims() == 0) { |
66 | return errors::InvalidArgument( |
67 | "Cannot concatenate a zero-dimensional tensor" ); |
68 | } |
69 | total_dim0_size += tensor.dim_size(0); |
70 | } |
71 | TensorShape shape = tensors[0].shape(); |
72 | shape.set_dim(0, total_dim0_size); |
73 | |
74 | const DataType dtype = tensors[0].dtype(); |
75 | for (int i = 1; i < tensors.size(); ++i) { |
76 | if (tensors[i].dtype() != dtype) { |
77 | return errors::InvalidArgument( |
78 | "Cannot concatenate tensors that have different data types." , " Got " , |
79 | DataTypeString(dtype), " and " , DataTypeString(tensors[i].dtype()), |
80 | "." ); |
81 | } |
82 | } |
83 | *result = Tensor(dtype, shape); |
84 | |
85 | // We use StringPiece as a convenient map over the tensor buffer, |
86 | // but we cast the type to get to the underlying buffer to do the |
87 | // copy. |
88 | StringPiece to_data = result->tensor_data(); |
89 | |
90 | if (DataTypeCanUseMemcpy(dtype)) { |
91 | int64_t offset = 0; |
92 | for (const Tensor& tensor : tensors) { |
93 | StringPiece from_data = tensor.tensor_data(); |
94 | CHECK_LE(offset + from_data.size(), to_data.size()); |
95 | memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(), |
96 | from_data.size()); |
97 | |
98 | offset += from_data.size(); |
99 | } |
100 | } else { |
101 | if (dtype != DT_STRING) { |
      return errors::Internal("Unexpected data type");
103 | } |
104 | tstring* to_strings = |
105 | reinterpret_cast<tstring*>(const_cast<char*>(to_data.data())); |
106 | |
107 | int64_t offset = 0; |
108 | for (const Tensor& tensor : tensors) { |
109 | auto from_strings = tensor.flat<tstring>(); |
110 | CHECK_LE(offset + tensor.NumElements(), result->NumElements()); |
111 | for (int i = 0; i < tensor.NumElements(); ++i) { |
112 | to_strings[offset + i] = from_strings(i); |
113 | } |
114 | |
115 | offset += tensor.NumElements(); |
116 | } |
117 | } |
118 | |
119 | return OkStatus(); |
120 | } |
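
// Illustrative sketch of Concat (tensor names are examples only):
// concatenating two [2, 3] tensors of the same dtype along dimension 0
// yields a [4, 3] tensor.
//
//   Tensor x(DT_FLOAT, TensorShape({2, 3}));
//   Tensor y(DT_FLOAT, TensorShape({2, 3}));
//   Tensor out;
//   Status s = Concat({x, y}, &out);  // out.shape() is [4, 3].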
121 | |
122 | Status Split(const Tensor& tensor, const gtl::ArraySlice<int64_t>& sizes, |
123 | std::vector<Tensor>* result) { |
124 | if (tensor.dims() == 0) { |
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
126 | } |
127 | int64_t total_size = 0; |
128 | for (int64_t size : sizes) { |
129 | total_size += size; |
130 | } |
131 | if (total_size != tensor.dim_size(0)) { |
132 | return errors::InvalidArgument( |
133 | "The values in 'sizes' do not sum to the zeroth-dimension size of " |
134 | "'tensor'" ); |
135 | } |
136 | |
137 | StringPiece from_data = tensor.tensor_data(); |
138 | |
139 | if (DataTypeCanUseMemcpy(tensor.dtype())) { |
140 | int64_t offset = 0; |
141 | for (int64_t size : sizes) { |
142 | TensorShape shape = tensor.shape(); |
143 | shape.set_dim(0, size); |
144 | result->emplace_back(tensor.dtype(), shape); |
145 | Tensor* split = &(*result)[result->size() - 1]; |
146 | |
147 | // We use StringPiece as a convenient map over the tensor buffer, |
148 | // but we cast the type to get to the underlying buffer to do the |
149 | // copy. |
150 | StringPiece to_data = split->tensor_data(); |
151 | CHECK_LE(offset + to_data.size(), from_data.size()); |
152 | memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset, |
153 | to_data.size()); |
154 | |
155 | offset += to_data.size(); |
156 | } |
157 | } else { |
158 | if (tensor.dtype() != DT_STRING) { |
      return errors::Internal("Unexpected data type");
160 | } |
161 | auto from_strings = tensor.flat<tstring>(); |
162 | |
163 | int64_t offset = 0; |
164 | for (int64_t size : sizes) { |
165 | TensorShape shape = tensor.shape(); |
166 | shape.set_dim(0, size); |
167 | result->emplace_back(tensor.dtype(), shape); |
168 | Tensor& split = (*result)[result->size() - 1]; |
169 | tstring* to_strings = reinterpret_cast<tstring*>( |
170 | const_cast<char*>(split.tensor_data().data())); |
171 | |
172 | CHECK_LE(offset + split.NumElements(), tensor.NumElements()); |
173 | for (int i = 0; i < split.NumElements(); ++i) { |
174 | to_strings[i] = from_strings(offset + i); |
175 | } |
176 | |
177 | offset += split.NumElements(); |
178 | } |
179 | } |
180 | |
181 | return OkStatus(); |
182 | } |
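
// Illustrative sketch of Split: the entries of 'sizes' must sum to dimension
// 0 of the input, and each output keeps the remaining dimensions.
//
//   Tensor t(DT_FLOAT, TensorShape({5, 2}));
//   std::vector<Tensor> parts;
//   Status s = Split(t, {2, 3}, &parts);
//   // parts[0].shape() is [2, 2]; parts[1].shape() is [3, 2].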
183 | |
184 | namespace internal { |
185 | void SetTensorProtoShape(std::vector<size_t> shape, |
186 | TensorShapeProto* shape_proto) { |
187 | for (auto dim : shape) { |
188 | shape_proto->mutable_dim()->Add()->set_size(dim); |
189 | } |
190 | } |
191 | |
192 | template <typename T> |
193 | bool CompressTensorContent(float min_compression_ratio, |
194 | const TensorShape& shape, TensorProto* tensor) { |
195 | using TypeHelper = internal::TensorProtoHelper<T>; |
196 | using FieldType = typename internal::TensorProtoHelper<T>::FieldType; |
197 | const int64_t num_tensor_values = shape.num_elements(); |
198 | const int64_t num_bytes = tensor->tensor_content().size(); |
199 | const int64_t num_raw_values = num_bytes / sizeof(T); |
200 | if (num_raw_values != num_tensor_values) { |
201 | // Invalid or too small. |
202 | return false; |
203 | } |
204 | int64_t last_offset = num_bytes - 1; |
205 | int64_t prev_offset = last_offset - sizeof(T); |
206 | // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements, |
207 | // starting from the end, to find the last pair of elements that are not |
208 | // identical. |
209 | while (prev_offset >= 0) { |
210 | if (tensor->tensor_content()[prev_offset] != |
211 | tensor->tensor_content()[last_offset]) { |
212 | break; |
213 | } |
214 | --last_offset; |
215 | --prev_offset; |
216 | } |
  if (prev_offset == -1) {
    // If this is a splat of value 0, it does not need an explicit value;
    // just erase the content.
220 | T splat_value; |
221 | port::CopySubrangeToArray(tensor->tensor_content(), 0, sizeof(T), |
222 | reinterpret_cast<char*>(&splat_value)); |
223 | if (splat_value == T(0)) { |
224 | tensor->clear_tensor_content(); |
225 | return true; |
226 | } |
227 | } |
  // Round up to the next whole number of elements of type T.
229 | const int64_t new_num_values = last_offset / sizeof(T) + 1; |
230 | if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) > |
231 | static_cast<int64_t>(num_bytes / min_compression_ratio)) { |
232 | return false; |
233 | } |
234 | // Copy values to truncated repeated field. |
235 | if constexpr (sizeof(FieldType) == sizeof(T)) { |
236 | FieldType* dst_ptr = |
237 | TypeHelper::AppendUninitialized(new_num_values, tensor); |
238 | port::CopySubrangeToArray(tensor->tensor_content(), 0, |
239 | new_num_values * sizeof(T), |
240 | reinterpret_cast<char*>(dst_ptr)); |
241 | tensor->clear_tensor_content(); |
242 | } else if constexpr (sizeof(T) > 1) { |
243 | // Copy raw bytes to temp array first, then cast. |
244 | gtl::InlinedVector<T, 64> tmp; |
245 | if (new_num_values >= tmp.max_size()) return false; |
246 | tmp.resize(new_num_values); |
247 | |
248 | port::CopySubrangeToArray(tensor->tensor_content(), 0, |
249 | new_num_values * sizeof(T), |
250 | reinterpret_cast<char*>(tmp.data())); |
251 | tensor->clear_tensor_content(); |
252 | TypeHelper::AddValues(tmp.begin(), tmp.end(), tensor); |
253 | } else { |
254 | // Copy and cast, one byte at a time. |
255 | for (int64_t i = 0; i < new_num_values; ++i) { |
256 | char c = tensor->tensor_content()[i]; |
257 | TypeHelper::AddValue(static_cast<T>(c), tensor); |
258 | } |
259 | tensor->clear_tensor_content(); |
260 | } |
261 | return true; |
262 | } |
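
// Worked example (illustrative): a DT_FLOAT tensor of 1000 elements stored in
// tensor_content occupies 4000 bytes. If every element is 0.0f, the scan
// above reaches prev_offset == -1 and the content is simply erased; if every
// element is 1.0f, the buffer is truncated to a single value in the repeated
// field, which readers re-splat across the full shape.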
263 | |
264 | template <typename T> |
265 | inline bool PackedValuesNotEqual(T a, T b) { |
266 | return a != b; |
267 | } |
268 | template <> |
269 | inline bool PackedValuesNotEqual(float a, float b) { |
270 | return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b); |
271 | } |
272 | template <> |
273 | inline bool PackedValuesNotEqual(double a, double b) { |
274 | return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b); |
275 | } |
276 | template <typename RealType> |
277 | inline bool PackedValuesNotEqual(const std::complex<RealType>& a, |
278 | const std::complex<RealType>& b) { |
279 | return PackedValuesNotEqual(a.real(), b.real()) || |
280 | PackedValuesNotEqual(a.imag(), b.imag()); |
281 | } |
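
// The bitwise float/double specializations above keep the trailing-run scan
// faithful to the stored bits: a plain != would make NaN unequal to itself
// (so NaN tails could never be truncated) and would treat -0.0 as equal to
// +0.0 (so re-splatting could flip a sign bit). For example,
// PackedValuesNotEqual(0.0f, -0.0f) is true even though 0.0f != -0.0f is
// false.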
282 | |
// Integers can't be negative zero.
284 | template <typename T, |
285 | typename std::enable_if<std::is_integral<T>::value>::type* = nullptr> |
286 | static bool IsNegativeZero(T value) { |
287 | return false; |
288 | } |
289 | |
290 | template <typename T, |
291 | typename std::enable_if<!std::is_integral<T>::value>::type* = nullptr> |
292 | static bool IsNegativeZero(T value) { |
293 | return value == T(0) && std::signbit(value); |
294 | } |
295 | |
296 | template <typename T> |
297 | static bool IsNegativeZero(std::complex<T> value) { |
298 | return IsNegativeZero(value.real()) || IsNegativeZero(value.imag()); |
299 | } |
300 | |
301 | static bool IsNegativeZero(Eigen::QUInt8 value) { return false; } |
302 | static bool IsNegativeZero(Eigen::QInt8 value) { return false; } |
303 | static bool IsNegativeZero(Eigen::QUInt16 value) { return false; } |
304 | static bool IsNegativeZero(Eigen::QInt16 value) { return false; } |
305 | static bool IsNegativeZero(Eigen::QInt32 value) { return false; } |
306 | static bool IsNegativeZero(Eigen::half value) { |
307 | return IsNegativeZero<float>(value); |
308 | } |
309 | static bool IsNegativeZero(Eigen::bfloat16 value) { |
310 | return IsNegativeZero<float>(value); |
311 | } |
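
// IsNegativeZero guards the all-zeros shortcut in CompressRepeatedField
// below: erasing the proto content re-materializes elements as +0.0, so a
// splat of -0.0 must not be compressed away.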
312 | |
313 | template <typename T> |
314 | bool CompressRepeatedField(float min_compression_ratio, |
315 | const TensorShape& shape, TensorProto* tensor) { |
316 | using TypeHelper = internal::TensorProtoHelper<T>; |
317 | using FieldType = typename internal::TensorProtoHelper<T>::FieldType; |
318 | const int64_t num_tensor_values = shape.num_elements(); |
319 | const int64_t num_proto_values = TypeHelper::NumValues(*tensor); |
320 | |
321 | // Notice that for complex types the tensor is stored as an array of up to |
322 | // 2 * num_tensor_values real values (real and imaginary parts), possibly |
323 | // truncated. A 0-splat does not need any value present and is maximally |
324 | // compressed. |
325 | if (num_proto_values == 0) return false; |
326 | |
327 | const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor); |
328 | int64_t last_index = 0; |
329 | for (int64_t i = num_proto_values - 2; i >= 0 && last_index == 0; --i) { |
330 | const T cur_value = TypeHelper::GetValue(i, *tensor); |
331 | if (PackedValuesNotEqual(cur_value, last_value)) { |
332 | last_index = i + 1; |
333 | } |
334 | } |
335 | |
  // Detect an all-zeros tensor: zero is the default value, so the content
  // can be erased entirely.
338 | if (last_index == 0 && last_value == T(0) && !IsNegativeZero(last_value)) { |
339 | TypeHelper::Truncate(0, tensor); |
340 | return true; |
341 | } |
342 | |
343 | const int64_t num_truncated_proto_values = last_index + 1; |
344 | const int64_t num_bytes_as_field = |
345 | num_truncated_proto_values * sizeof(FieldType); |
346 | const int64_t num_bytes_as_tensor_content = num_tensor_values * sizeof(T); |
347 | const int64_t num_bytes_before = num_proto_values * sizeof(FieldType); |
348 | if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) > |
349 | static_cast<int64_t>(num_bytes_before / min_compression_ratio)) { |
350 | return false; |
351 | } |
352 | if (num_bytes_as_field <= num_bytes_as_tensor_content) { |
353 | TypeHelper::Truncate(num_truncated_proto_values, tensor); |
354 | } else { |
355 | gtl::InlinedVector<T, 64> tmp; |
356 | if (num_proto_values == 1) { |
357 | // Splat case. |
358 | tmp.resize(num_tensor_values, last_value); |
359 | } else { |
360 | tmp.resize(num_tensor_values, T(0)); |
361 | TypeHelper::CopyValues(tmp.begin(), *tensor); |
362 | } |
363 | TypeHelper::Truncate(0, tensor); |
364 | port::CopyFromArray(tensor->mutable_tensor_content(), |
365 | reinterpret_cast<const char*>(tmp.data()), |
366 | num_bytes_as_tensor_content); |
367 | } |
368 | return true; |
369 | } |
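
// Worked example (illustrative): a DT_INT32 tensor of 1000 elements stored as
// repeated int_val entries {5, 0, 0, ..., 0} truncates to {5, 0}; readers
// fill the remaining elements by repeating the final entry. A pure zero splat
// {0, 0, ..., 0} erases the field entirely.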
370 | |
371 | template <typename T> |
372 | bool CompressTensorProtoInPlaceImpl(int64_t min_num_elements, |
373 | float min_compression_ratio, |
374 | TensorProto* tensor) { |
375 | const TensorShape shape(tensor->tensor_shape()); |
376 | const int64_t num_tensor_values = shape.num_elements(); |
377 | if (num_tensor_values < min_num_elements) { |
378 | return false; |
379 | } |
380 | if (tensor->tensor_content().empty()) { |
381 | return CompressRepeatedField<T>(min_compression_ratio, shape, tensor); |
382 | } else { |
383 | return CompressTensorContent<T>(min_compression_ratio, shape, tensor); |
384 | } |
386 | } |
387 | |
388 | } // namespace internal |
389 | |
390 | #define HANDLE_COMPRESS_CASE(TF_TYPE) \ |
391 | case TF_TYPE: \ |
392 | return internal::CompressTensorProtoInPlaceImpl< \ |
393 | EnumToDataType<TF_TYPE>::Type>(min_num_elements, \ |
394 | min_compression_ratio, tensor); \ |
395 | break |
396 | |
397 | bool CompressTensorProtoInPlace(int64_t min_num_elements, |
398 | float min_compression_ratio, |
399 | TensorProto* tensor) { |
400 | switch (tensor->dtype()) { |
401 | HANDLE_COMPRESS_CASE(DT_FLOAT); |
402 | HANDLE_COMPRESS_CASE(DT_DOUBLE); |
403 | HANDLE_COMPRESS_CASE(DT_COMPLEX64); |
404 | HANDLE_COMPRESS_CASE(DT_COMPLEX128); |
405 | HANDLE_COMPRESS_CASE(DT_UINT8); |
406 | HANDLE_COMPRESS_CASE(DT_INT8); |
407 | HANDLE_COMPRESS_CASE(DT_UINT16); |
408 | HANDLE_COMPRESS_CASE(DT_INT16); |
409 | HANDLE_COMPRESS_CASE(DT_UINT32); |
410 | HANDLE_COMPRESS_CASE(DT_INT32); |
411 | HANDLE_COMPRESS_CASE(DT_UINT64); |
412 | HANDLE_COMPRESS_CASE(DT_INT64); |
413 | HANDLE_COMPRESS_CASE(DT_BOOL); |
414 | HANDLE_COMPRESS_CASE(DT_QUINT8); |
415 | HANDLE_COMPRESS_CASE(DT_QINT8); |
416 | HANDLE_COMPRESS_CASE(DT_QUINT16); |
417 | HANDLE_COMPRESS_CASE(DT_QINT16); |
418 | HANDLE_COMPRESS_CASE(DT_QINT32); |
419 | HANDLE_COMPRESS_CASE(DT_HALF); |
420 | HANDLE_COMPRESS_CASE(DT_BFLOAT16); |
421 | default: |
422 | return false; |
423 | } |
424 | } |
425 | |
426 | #undef HANDLE_COMPRESS_CASE |
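
// Example use (an illustrative sketch; `t` is a placeholder Tensor): shrink a
// large constant before serializing it into a GraphDef.
//
//   TensorProto proto;
//   t.AsProtoTensorContent(&proto);
//   CompressTensorProtoInPlace(/*min_num_elements=*/64,
//                              /*min_compression_ratio=*/2.0f, &proto);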
427 | |
428 | Status MakeShape(const Tensor& shape, TensorShape* out) { |
429 | if (!TensorShapeUtils::IsVector(shape.shape())) { |
430 | return errors::InvalidArgument( |
431 | "shape must be a vector of {int32,int64}, got shape " , |
432 | shape.shape().DebugString()); |
433 | } |
434 | if (shape.dtype() == DataType::DT_INT32) { |
435 | auto vec = shape.flat<int32>(); |
436 | return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out); |
437 | } else if (shape.dtype() == DataType::DT_INT64) { |
438 | auto vec = shape.flat<int64_t>(); |
439 | return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out); |
440 | } else { |
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
442 | } |
443 | } |
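
// Illustrative sketch of MakeShape: converting a 1-D DT_INT32 tensor holding
// {2, 3, 4} into a TensorShape.
//
//   Tensor shape_tensor(DT_INT32, TensorShape({3}));
//   auto v = shape_tensor.flat<int32>();
//   v(0) = 2; v(1) = 3; v(2) = 4;
//   TensorShape shape;
//   Status s = MakeShape(shape_tensor, &shape);  // shape is [2, 3, 4].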
444 | |
445 | } // namespace tensor |
446 | } // namespace tensorflow |
447 | |