1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/data/compression_utils.h" |
16 | |
17 | #include <limits> |
18 | |
19 | #include "tensorflow/core/common_runtime/dma_helper.h" |
20 | #include "tensorflow/core/framework/tensor.pb.h" |
21 | #include "tensorflow/core/platform/snappy.h" |
22 | #include "tensorflow/core/platform/types.h" |
23 | |
24 | namespace tensorflow { |
25 | namespace data { |
26 | |
27 | Status CompressElement(const std::vector<Tensor>& element, |
28 | CompressedElement* out) { |
29 | // Step 1: Determine the total uncompressed size. This requires serializing |
30 | // non-memcopyable tensors, which we save to use again later. |
31 | std::vector<TensorProto> non_memcpy_components; |
32 | size_t total_size = 0; |
33 | for (auto& component : element) { |
34 | if (DataTypeCanUseMemcpy(component.dtype())) { |
35 | const TensorBuffer* buffer = DMAHelper::buffer(&component); |
36 | if (buffer) { |
37 | total_size += buffer->size(); |
38 | } |
39 | } else { |
40 | non_memcpy_components.emplace_back(); |
41 | component.AsProtoTensorContent(&non_memcpy_components.back()); |
42 | total_size += non_memcpy_components.back().ByteSizeLong(); |
43 | } |
44 | } |
45 | |
46 | // Step 2: Write the tensor data to a buffer, and compress that buffer. |
47 | // We use tstring for access to resize_uninitialized. |
48 | tstring uncompressed; |
49 | uncompressed.resize_uninitialized(total_size); |
50 | // Position in `uncompressed` to write the next component. |
51 | char* position = uncompressed.mdata(); |
52 | int non_memcpy_component_index = 0; |
53 | for (auto& component : element) { |
54 | CompressedComponentMetadata* metadata = |
55 | out->mutable_component_metadata()->Add(); |
56 | metadata->set_dtype(component.dtype()); |
57 | component.shape().AsProto(metadata->mutable_tensor_shape()); |
58 | if (DataTypeCanUseMemcpy(component.dtype())) { |
59 | const TensorBuffer* buffer = DMAHelper::buffer(&component); |
60 | if (buffer) { |
61 | memcpy(position, buffer->data(), buffer->size()); |
62 | metadata->set_tensor_size_bytes(buffer->size()); |
63 | } |
64 | } else { |
65 | TensorProto& proto = non_memcpy_components[non_memcpy_component_index++]; |
66 | proto.SerializeToArray(position, proto.ByteSizeLong()); |
67 | metadata->set_tensor_size_bytes(proto.ByteSizeLong()); |
68 | } |
69 | position += metadata->tensor_size_bytes(); |
70 | } |
71 | if (total_size > kuint32max) { |
72 | return errors::OutOfRange("Encountered dataset element of size " , |
73 | total_size, ", exceeding the 4GB Snappy limit." ); |
74 | } |
75 | DCHECK_EQ(position, uncompressed.mdata() + total_size); |
76 | |
77 | if (!port::Snappy_Compress(uncompressed.mdata(), total_size, |
78 | out->mutable_data())) { |
79 | return errors::Internal("Failed to compress using snappy." ); |
80 | } |
81 | VLOG(3) << "Compressed element from " << total_size << " bytes to " |
82 | << out->data().size() << " bytes" ; |
83 | return OkStatus(); |
84 | } |
85 | |
// Decompresses `compressed` (produced by CompressElement) into `out`,
// reconstructing one Tensor per recorded component metadata entry.
//
// Decompresses directly into the destination tensor buffers via an iovec
// scatter list, avoiding an intermediate copy for memcpy-able components.
// Returns Internal if the Snappy data is malformed, its uncompressed length
// disagrees with the metadata, or a TensorProto fails to parse.
Status UncompressElement(const CompressedElement& compressed,
                         std::vector<Tensor>* out) {
  int num_components = compressed.component_metadata_size();
  out->clear();
  out->reserve(num_components);

  // Step 1: Prepare the memory that we will uncompress into.
  std::vector<struct iovec> iov(num_components);
  // We use tstring for access to resize_uninitialized.
  std::vector<tstring> tensor_proto_strs;
  // num_components is a conservative estimate. It is important to reserve
  // vector space so that the vector doesn't resize itself, which could
  // invalidate pointers to its strings' data.
  tensor_proto_strs.reserve(num_components);
  int64_t total_size = 0;
  for (int i = 0; i < num_components; ++i) {
    const CompressedComponentMetadata& metadata =
        compressed.component_metadata(i);
    if (DataTypeCanUseMemcpy(metadata.dtype())) {
      // Allocate the output tensor now and point the iovec straight at its
      // buffer, so Snappy decompresses into it with no intermediate copy.
      out->emplace_back(metadata.dtype(), metadata.tensor_shape());
      TensorBuffer* buffer = DMAHelper::buffer(&out->back());
      if (buffer) {
        iov[i].iov_base = buffer->data();
        iov[i].iov_len = buffer->size();
      } else {
        // Empty tensor: nothing to decompress into for this component.
        iov[i].iov_base = nullptr;
        iov[i].iov_len = 0;
      }
    } else {
      // Allocate an empty Tensor. We will fill it out later after
      // uncompressing into the tensor_proto_str.
      out->emplace_back();
      tensor_proto_strs.emplace_back();
      tstring& tensor_proto_str = tensor_proto_strs.back();
      // tensor_size_bytes comes from the (trusted) metadata written by
      // CompressElement; it is cross-checked against Snappy's own
      // uncompressed-length record below before decompression happens.
      tensor_proto_str.resize_uninitialized(metadata.tensor_size_bytes());
      iov[i].iov_base = tensor_proto_str.mdata();
      iov[i].iov_len = tensor_proto_str.size();
    }
    total_size += iov[i].iov_len;
  }

  // Step 2: Uncompress into the iovec.
  const std::string& compressed_data = compressed.data();
  size_t uncompressed_size;
  if (!port::Snappy_GetUncompressedLength(
          compressed_data.data(), compressed_data.size(), &uncompressed_size)) {
    return errors::Internal(
        "Could not get snappy uncompressed length. Compressed data size: ",
        compressed_data.size());
  }
  // Guard against corrupt or inconsistent input: the size implied by the
  // metadata must match what the Snappy stream itself claims, otherwise the
  // scatter decompression below could under- or over-fill the buffers.
  if (uncompressed_size != static_cast<size_t>(total_size)) {
    return errors::Internal(
        "Uncompressed size mismatch. Snappy expects ", uncompressed_size,
        " whereas the tensor metadata suggests ", total_size);
  }
  if (!port::Snappy_UncompressToIOVec(compressed_data.data(),
                                      compressed_data.size(), iov.data(),
                                      num_components)) {
    return errors::Internal("Failed to perform snappy decompression.");
  }

  // Step 3: Deserialize tensor proto strings to tensors.
  // tensor_proto_strs holds entries only for non-memcpy components, so it is
  // walked with its own index while i tracks all components.
  int tensor_proto_strs_index = 0;
  for (int i = 0; i < num_components; ++i) {
    if (DataTypeCanUseMemcpy(compressed.component_metadata(i).dtype())) {
      continue;
    }
    TensorProto tp;
    if (!tp.ParseFromString(tensor_proto_strs[tensor_proto_strs_index++])) {
      return errors::Internal("Could not parse TensorProto");
    }
    if (!out->at(i).FromProto(tp)) {
      return errors::Internal("Could not parse Tensor");
    }
  }
  return OkStatus();
}
163 | |
164 | } // namespace data |
165 | } // namespace tensorflow |
166 | |