1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/platform/tensor_coding.h" |
17 | |
18 | #include <vector> |
19 | |
20 | #include "tensorflow/core/platform/coding.h" |
21 | #include "tensorflow/core/platform/protobuf.h" |
22 | #include "tensorflow/core/platform/strcat.h" |
23 | #include "tensorflow/core/platform/stringpiece.h" |
24 | |
25 | #if defined(TENSORFLOW_PROTOBUF_USES_CORD) |
26 | #include "strings/cord_varint.h" |
27 | #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) |
28 | |
29 | namespace tensorflow { |
30 | namespace port { |
31 | |
32 | void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) { |
33 | out->assign(src.data(), src.size()); |
34 | } |
35 | |
36 | void EncodeStringList(const tstring* strings, int64_t n, string* out) { |
37 | out->clear(); |
38 | for (int i = 0; i < n; ++i) { |
39 | core::PutVarint32(out, strings[i].size()); |
40 | } |
41 | for (int i = 0; i < n; ++i) { |
42 | out->append(strings[i]); |
43 | } |
44 | } |
45 | |
46 | bool DecodeStringList(const string& src, tstring* strings, int64_t n) { |
47 | std::vector<uint32> sizes(n); |
48 | StringPiece reader(src); |
49 | int64_t tot = 0; |
50 | for (auto& v : sizes) { |
51 | if (!core::GetVarint32(&reader, &v)) return false; |
52 | tot += v; |
53 | } |
54 | if (tot != static_cast<int64_t>(reader.size())) { |
55 | return false; |
56 | } |
57 | |
58 | tstring* data = strings; |
59 | for (int64_t i = 0; i < n; ++i, ++data) { |
60 | auto size = sizes[i]; |
61 | if (size > reader.size()) { |
62 | return false; |
63 | } |
64 | data->assign(reader.data(), size); |
65 | reader.remove_prefix(size); |
66 | } |
67 | |
68 | return true; |
69 | } |
70 | |
// Replaces the contents of `*s` with the `bytes` bytes starting at `base`.
void CopyFromArray(string* s, const char* base, size_t bytes) {
  string& destination = *s;
  destination.assign(base, bytes);
}
74 | |
75 | class StringListEncoderImpl : public StringListEncoder { |
76 | public: |
77 | explicit StringListEncoderImpl(string* out) : out_(out) {} |
78 | ~StringListEncoderImpl() override = default; |
79 | |
80 | void Append(const protobuf::MessageLite& m) override { |
81 | core::PutVarint32(out_, m.ByteSizeLong()); |
82 | tensorflow::string serialized_message; |
83 | m.AppendToString(&serialized_message); |
84 | strings::StrAppend(&rest_, serialized_message); |
85 | } |
86 | |
87 | void Append(const string& s) override { |
88 | core::PutVarint32(out_, s.length()); |
89 | strings::StrAppend(&rest_, s); |
90 | } |
91 | |
92 | void Finalize() override { strings::StrAppend(out_, rest_); } |
93 | |
94 | private: |
95 | string* out_; |
96 | string rest_; |
97 | }; |
98 | |
99 | class StringListDecoderImpl : public StringListDecoder { |
100 | public: |
101 | explicit StringListDecoderImpl(const string& in) : reader_(in) {} |
102 | ~StringListDecoderImpl() override = default; |
103 | |
104 | bool ReadSizes(std::vector<uint32>* sizes) override { |
105 | int64_t total = 0; |
106 | for (auto& size : *sizes) { |
107 | if (!core::GetVarint32(&reader_, &size)) return false; |
108 | total += size; |
109 | } |
110 | if (total != static_cast<int64_t>(reader_.size())) { |
111 | return false; |
112 | } |
113 | return true; |
114 | } |
115 | |
116 | const char* Data(uint32 size) override { |
117 | const char* data = reader_.data(); |
118 | reader_.remove_prefix(size); |
119 | return data; |
120 | } |
121 | |
122 | private: |
123 | StringPiece reader_; |
124 | }; |
125 | |
126 | std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out) { |
127 | return std::unique_ptr<StringListEncoder>(new StringListEncoderImpl(out)); |
128 | } |
129 | |
130 | std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in) { |
131 | return std::unique_ptr<StringListDecoder>(new StringListDecoderImpl(in)); |
132 | } |
133 | |
134 | #if defined(TENSORFLOW_PROTOBUF_USES_CORD) |
135 | void AssignRefCounted(StringPiece src, core::RefCounted* obj, absl::Cord* out) { |
136 | obj->Ref(); |
137 | *out = absl::MakeCordFromExternal(src, [obj] { obj->Unref(); }); |
138 | } |
139 | |
140 | void EncodeStringList(const tstring* strings, int64_t n, absl::Cord* out) { |
141 | out->Clear(); |
142 | for (int i = 0; i < n; ++i) { |
143 | ::strings::CordAppendVarint(strings[i].size(), out); |
144 | } |
145 | for (int i = 0; i < n; ++i) { |
146 | out->Append(strings[i]); |
147 | } |
148 | } |
149 | |
150 | bool DecodeStringList(const absl::Cord& src, string* strings, int64_t n) { |
151 | std::vector<uint32> sizes(n); |
152 | CordReader reader(src); |
153 | int64_t tot = 0; |
154 | for (auto& v : sizes) { |
155 | if (!::strings::CordReaderReadVarint(&reader, &v)) return false; |
156 | tot += v; |
157 | } |
158 | if (tot != reader.Available()) { |
159 | return false; |
160 | } |
161 | string* data = strings; |
162 | for (int i = 0; i < n; ++i, ++data) { |
163 | auto size = sizes[i]; |
164 | if (size > reader.Available()) { |
165 | return false; |
166 | } |
167 | gtl::STLStringResizeUninitialized(data, size); |
168 | reader.ReadN(size, gtl::string_as_array(data)); |
169 | } |
170 | return true; |
171 | } |
172 | |
173 | bool DecodeStringList(const absl::Cord& src, tstring* strings, int64_t n) { |
174 | std::vector<uint32> sizes(n); |
175 | CordReader reader(src); |
176 | int64_t tot = 0; |
177 | for (auto& v : sizes) { |
178 | if (!::strings::CordReaderReadVarint(&reader, &v)) return false; |
179 | tot += v; |
180 | } |
181 | if (tot != reader.Available()) { |
182 | return false; |
183 | } |
184 | tstring* data = strings; |
185 | for (int i = 0; i < n; ++i, ++data) { |
186 | auto size = sizes[i]; |
187 | if (size > reader.Available()) { |
188 | return false; |
189 | } |
190 | data->resize_uninitialized(size); |
191 | reader.ReadN(size, data->data()); |
192 | } |
193 | return true; |
194 | } |
195 | |
196 | void CopyFromArray(absl::Cord* c, const char* base, size_t bytes) { |
197 | c->CopyFrom(base, bytes); |
198 | } |
199 | |
200 | class CordStringListEncoderImpl : public StringListEncoder { |
201 | public: |
202 | explicit CordStringListEncoderImpl(absl::Cord* out) : out_(out) {} |
203 | ~CordStringListEncoderImpl() override = default; |
204 | |
205 | void Append(const protobuf::MessageLite& m) override { |
206 | ::strings::CordAppendVarint(m.ByteSizeLong(), out_); |
207 | m.AppendToString(&rest_); |
208 | } |
209 | |
210 | void Append(const string& s) override { |
211 | ::strings::CordAppendVarint(s.length(), out_); |
212 | rest_.append(s.data(), s.size()); |
213 | } |
214 | |
215 | void Finalize() override { out_->Append(rest_); } |
216 | |
217 | private: |
218 | absl::Cord* out_; |
219 | string rest_; |
220 | }; |
221 | |
222 | class CordStringListDecoderImpl : public StringListDecoder { |
223 | public: |
224 | explicit CordStringListDecoderImpl(const absl::Cord& in) : reader_(in) {} |
225 | ~CordStringListDecoderImpl() override = default; |
226 | |
227 | bool ReadSizes(std::vector<uint32>* sizes) override { |
228 | int64_t total = 0; |
229 | for (auto& size : *sizes) { |
230 | if (!::strings::CordReaderReadVarint(&reader_, &size)) return false; |
231 | total += size; |
232 | } |
233 | if (total != static_cast<int64_t>(reader_.Available())) { |
234 | return false; |
235 | } |
236 | return true; |
237 | } |
238 | |
239 | const char* Data(uint32 size) override { |
240 | tmp_.resize(size); |
241 | reader_.ReadN(size, tmp_.data()); |
242 | return tmp_.data(); |
243 | } |
244 | |
245 | private: |
246 | CordReader reader_; |
247 | std::vector<char> tmp_; |
248 | }; |
249 | |
250 | std::unique_ptr<StringListEncoder> NewStringListEncoder(absl::Cord* out) { |
251 | return std::unique_ptr<StringListEncoder>(new CordStringListEncoderImpl(out)); |
252 | } |
253 | |
254 | std::unique_ptr<StringListDecoder> NewStringListDecoder(const absl::Cord& in) { |
255 | return std::unique_ptr<StringListDecoder>(new CordStringListDecoderImpl(in)); |
256 | } |
257 | |
258 | #endif // defined(TENSORFLOW_PROTOBUF_USES_CORD) |
259 | |
260 | } // namespace port |
261 | } // namespace tensorflow |
262 | |