/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/compression_utils.h"

#include <cstring>
#include <limits>
#include <string>
#include <vector>

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace data {

Status CompressElement(const std::vector<Tensor>& element,
                       CompressedElement* out) {
  // Step 1: Determine the total uncompressed size. This requires serializing
  // non-memcopyable tensors, which we save to use again later.
  std::vector<TensorProto> non_memcpy_components;
  size_t total_size = 0;
  for (auto& component : element) {
    if (DataTypeCanUseMemcpy(component.dtype())) {
      const TensorBuffer* buffer = DMAHelper::buffer(&component);
      if (buffer) {
        total_size += buffer->size();
      }
    } else {
      non_memcpy_components.emplace_back();
      component.AsProtoTensorContent(&non_memcpy_components.back());
      total_size += non_memcpy_components.back().ByteSizeLong();
    }
  }

  // Step 2: Write the tensor data to a buffer, and compress that buffer.
  // We use tstring for access to resize_uninitialized.
  tstring uncompressed;
  uncompressed.resize_uninitialized(total_size);
  // Position in `uncompressed` to write the next component.
  char* position = uncompressed.mdata();
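  // Index into `non_memcpy_components` of the next serialized proto to copy.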
  int non_memcpy_component_index = 0;
  for (auto& component : element) {
    CompressedComponentMetadata* metadata =
        out->mutable_component_metadata()->Add();
    metadata->set_dtype(component.dtype());
    component.shape().AsProto(metadata->mutable_tensor_shape());
    if (DataTypeCanUseMemcpy(component.dtype())) {
      const TensorBuffer* buffer = DMAHelper::buffer(&component);
      if (buffer) {
        memcpy(position, buffer->data(), buffer->size());
        metadata->set_tensor_size_bytes(buffer->size());
      }
    } else {
      TensorProto& proto = non_memcpy_components[non_memcpy_component_index++];
      proto.SerializeToArray(position, proto.ByteSizeLong());
      metadata->set_tensor_size_bytes(proto.ByteSizeLong());
    }
    position += metadata->tensor_size_bytes();
  }
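  // Snappy stores the uncompressed length as a 32-bit value, so elements
  // larger than 4GB cannot be compressed.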
  if (total_size > kuint32max) {
    return errors::OutOfRange("Encountered dataset element of size ",
                              total_size, ", exceeding the 4GB Snappy limit.");
  }
  DCHECK_EQ(position, uncompressed.mdata() + total_size);

  if (!port::Snappy_Compress(uncompressed.mdata(), total_size,
                             out->mutable_data())) {
    return errors::Internal("Failed to compress using snappy.");
  }
  VLOG(3) << "Compressed element from " << total_size << " bytes to "
          << out->data().size() << " bytes";
  return OkStatus();
}

Status UncompressElement(const CompressedElement& compressed,
                         std::vector<Tensor>* out) {
  int num_components = compressed.component_metadata_size();
  out->clear();
  out->reserve(num_components);

  // Step 1: Prepare the memory that we will uncompress into.
  std::vector<struct iovec> iov(num_components);
  // We use tstring for access to resize_uninitialized.
  std::vector<tstring> tensor_proto_strs;
  // num_components is a conservative estimate. It is important to reserve
  // vector space so that the vector doesn't resize itself, which could
  // invalidate pointers to its strings' data.
  tensor_proto_strs.reserve(num_components);
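  // Total number of bytes expected across all components; used below to
  // validate the uncompressed length recorded in the Snappy data.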
  int64_t total_size = 0;
  for (int i = 0; i < num_components; ++i) {
    const CompressedComponentMetadata& metadata =
        compressed.component_metadata(i);
    if (DataTypeCanUseMemcpy(metadata.dtype())) {
      out->emplace_back(metadata.dtype(), metadata.tensor_shape());
      TensorBuffer* buffer = DMAHelper::buffer(&out->back());
      if (buffer) {
        iov[i].iov_base = buffer->data();
        iov[i].iov_len = buffer->size();
      } else {
        iov[i].iov_base = nullptr;
        iov[i].iov_len = 0;
      }
    } else {
      // Allocate an empty Tensor. We will fill it out later after
      // uncompressing into the tensor_proto_str.
      out->emplace_back();
      tensor_proto_strs.emplace_back();
      tstring& tensor_proto_str = tensor_proto_strs.back();
      tensor_proto_str.resize_uninitialized(metadata.tensor_size_bytes());
      iov[i].iov_base = tensor_proto_str.mdata();
      iov[i].iov_len = tensor_proto_str.size();
    }
    total_size += iov[i].iov_len;
  }

  // Step 2: Uncompress into the iovec.
  const std::string& compressed_data = compressed.data();
  size_t uncompressed_size;
  if (!port::Snappy_GetUncompressedLength(
          compressed_data.data(), compressed_data.size(), &uncompressed_size)) {
    return errors::Internal(
        "Could not get snappy uncompressed length. Compressed data size: ",
        compressed_data.size());
  }
  if (uncompressed_size != static_cast<size_t>(total_size)) {
    return errors::Internal(
        "Uncompressed size mismatch. Snappy expects ", uncompressed_size,
        " whereas the tensor metadata suggests ", total_size);
  }
  if (!port::Snappy_UncompressToIOVec(compressed_data.data(),
                                      compressed_data.size(), iov.data(),
                                      num_components)) {
    return errors::Internal("Failed to perform snappy decompression.");
  }

  // Step 3: Deserialize tensor proto strings to tensors.
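  // Memcpy-able components were written directly into their tensor buffers by
  // the iovec decompression above, so only the proto-encoded components still
  // need to be parsed.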
  int tensor_proto_strs_index = 0;
  for (int i = 0; i < num_components; ++i) {
    if (DataTypeCanUseMemcpy(compressed.component_metadata(i).dtype())) {
      continue;
    }
    TensorProto tp;
    if (!tp.ParseFromString(tensor_proto_strs[tensor_proto_strs_index++])) {
      return errors::Internal("Could not parse TensorProto");
    }
    if (!out->at(i).FromProto(tp)) {
      return errors::Internal("Could not parse Tensor");
    }
  }
  return OkStatus();
}

}  // namespace data
}  // namespace tensorflow