1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/platform/tensor_coding.h"
17
18#include <vector>
19
20#include "tensorflow/core/platform/coding.h"
21#include "tensorflow/core/platform/protobuf.h"
22#include "tensorflow/core/platform/strcat.h"
23#include "tensorflow/core/platform/stringpiece.h"
24
25#if defined(TENSORFLOW_PROTOBUF_USES_CORD)
26#include "strings/cord_varint.h"
27#endif // defined(TENSORFLOW_PROTOBUF_USES_CORD)
28
29namespace tensorflow {
30namespace port {
31
32void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) {
33 out->assign(src.data(), src.size());
34}
35
36void EncodeStringList(const tstring* strings, int64_t n, string* out) {
37 out->clear();
38 for (int i = 0; i < n; ++i) {
39 core::PutVarint32(out, strings[i].size());
40 }
41 for (int i = 0; i < n; ++i) {
42 out->append(strings[i]);
43 }
44}
45
46bool DecodeStringList(const string& src, tstring* strings, int64_t n) {
47 std::vector<uint32> sizes(n);
48 StringPiece reader(src);
49 int64_t tot = 0;
50 for (auto& v : sizes) {
51 if (!core::GetVarint32(&reader, &v)) return false;
52 tot += v;
53 }
54 if (tot != static_cast<int64_t>(reader.size())) {
55 return false;
56 }
57
58 tstring* data = strings;
59 for (int64_t i = 0; i < n; ++i, ++data) {
60 auto size = sizes[i];
61 if (size > reader.size()) {
62 return false;
63 }
64 data->assign(reader.data(), size);
65 reader.remove_prefix(size);
66 }
67
68 return true;
69}
70
// Replaces the contents of `*s` with the `bytes` bytes starting at `base`.
void CopyFromArray(string* s, const char* base, size_t bytes) {
  string& target = *s;
  target.assign(base, bytes);
}
74
75class StringListEncoderImpl : public StringListEncoder {
76 public:
77 explicit StringListEncoderImpl(string* out) : out_(out) {}
78 ~StringListEncoderImpl() override = default;
79
80 void Append(const protobuf::MessageLite& m) override {
81 core::PutVarint32(out_, m.ByteSizeLong());
82 tensorflow::string serialized_message;
83 m.AppendToString(&serialized_message);
84 strings::StrAppend(&rest_, serialized_message);
85 }
86
87 void Append(const string& s) override {
88 core::PutVarint32(out_, s.length());
89 strings::StrAppend(&rest_, s);
90 }
91
92 void Finalize() override { strings::StrAppend(out_, rest_); }
93
94 private:
95 string* out_;
96 string rest_;
97};
98
99class StringListDecoderImpl : public StringListDecoder {
100 public:
101 explicit StringListDecoderImpl(const string& in) : reader_(in) {}
102 ~StringListDecoderImpl() override = default;
103
104 bool ReadSizes(std::vector<uint32>* sizes) override {
105 int64_t total = 0;
106 for (auto& size : *sizes) {
107 if (!core::GetVarint32(&reader_, &size)) return false;
108 total += size;
109 }
110 if (total != static_cast<int64_t>(reader_.size())) {
111 return false;
112 }
113 return true;
114 }
115
116 const char* Data(uint32 size) override {
117 const char* data = reader_.data();
118 reader_.remove_prefix(size);
119 return data;
120 }
121
122 private:
123 StringPiece reader_;
124};
125
126std::unique_ptr<StringListEncoder> NewStringListEncoder(string* out) {
127 return std::unique_ptr<StringListEncoder>(new StringListEncoderImpl(out));
128}
129
130std::unique_ptr<StringListDecoder> NewStringListDecoder(const string& in) {
131 return std::unique_ptr<StringListDecoder>(new StringListDecoderImpl(in));
132}
133
134#if defined(TENSORFLOW_PROTOBUF_USES_CORD)
135void AssignRefCounted(StringPiece src, core::RefCounted* obj, absl::Cord* out) {
136 obj->Ref();
137 *out = absl::MakeCordFromExternal(src, [obj] { obj->Unref(); });
138}
139
140void EncodeStringList(const tstring* strings, int64_t n, absl::Cord* out) {
141 out->Clear();
142 for (int i = 0; i < n; ++i) {
143 ::strings::CordAppendVarint(strings[i].size(), out);
144 }
145 for (int i = 0; i < n; ++i) {
146 out->Append(strings[i]);
147 }
148}
149
150bool DecodeStringList(const absl::Cord& src, string* strings, int64_t n) {
151 std::vector<uint32> sizes(n);
152 CordReader reader(src);
153 int64_t tot = 0;
154 for (auto& v : sizes) {
155 if (!::strings::CordReaderReadVarint(&reader, &v)) return false;
156 tot += v;
157 }
158 if (tot != reader.Available()) {
159 return false;
160 }
161 string* data = strings;
162 for (int i = 0; i < n; ++i, ++data) {
163 auto size = sizes[i];
164 if (size > reader.Available()) {
165 return false;
166 }
167 gtl::STLStringResizeUninitialized(data, size);
168 reader.ReadN(size, gtl::string_as_array(data));
169 }
170 return true;
171}
172
173bool DecodeStringList(const absl::Cord& src, tstring* strings, int64_t n) {
174 std::vector<uint32> sizes(n);
175 CordReader reader(src);
176 int64_t tot = 0;
177 for (auto& v : sizes) {
178 if (!::strings::CordReaderReadVarint(&reader, &v)) return false;
179 tot += v;
180 }
181 if (tot != reader.Available()) {
182 return false;
183 }
184 tstring* data = strings;
185 for (int i = 0; i < n; ++i, ++data) {
186 auto size = sizes[i];
187 if (size > reader.Available()) {
188 return false;
189 }
190 data->resize_uninitialized(size);
191 reader.ReadN(size, data->data());
192 }
193 return true;
194}
195
196void CopyFromArray(absl::Cord* c, const char* base, size_t bytes) {
197 c->CopyFrom(base, bytes);
198}
199
200class CordStringListEncoderImpl : public StringListEncoder {
201 public:
202 explicit CordStringListEncoderImpl(absl::Cord* out) : out_(out) {}
203 ~CordStringListEncoderImpl() override = default;
204
205 void Append(const protobuf::MessageLite& m) override {
206 ::strings::CordAppendVarint(m.ByteSizeLong(), out_);
207 m.AppendToString(&rest_);
208 }
209
210 void Append(const string& s) override {
211 ::strings::CordAppendVarint(s.length(), out_);
212 rest_.append(s.data(), s.size());
213 }
214
215 void Finalize() override { out_->Append(rest_); }
216
217 private:
218 absl::Cord* out_;
219 string rest_;
220};
221
222class CordStringListDecoderImpl : public StringListDecoder {
223 public:
224 explicit CordStringListDecoderImpl(const absl::Cord& in) : reader_(in) {}
225 ~CordStringListDecoderImpl() override = default;
226
227 bool ReadSizes(std::vector<uint32>* sizes) override {
228 int64_t total = 0;
229 for (auto& size : *sizes) {
230 if (!::strings::CordReaderReadVarint(&reader_, &size)) return false;
231 total += size;
232 }
233 if (total != static_cast<int64_t>(reader_.Available())) {
234 return false;
235 }
236 return true;
237 }
238
239 const char* Data(uint32 size) override {
240 tmp_.resize(size);
241 reader_.ReadN(size, tmp_.data());
242 return tmp_.data();
243 }
244
245 private:
246 CordReader reader_;
247 std::vector<char> tmp_;
248};
249
250std::unique_ptr<StringListEncoder> NewStringListEncoder(absl::Cord* out) {
251 return std::unique_ptr<StringListEncoder>(new CordStringListEncoderImpl(out));
252}
253
254std::unique_ptr<StringListDecoder> NewStringListDecoder(const absl::Cord& in) {
255 return std::unique_ptr<StringListDecoder>(new CordStringListDecoderImpl(in));
256}
257
258#endif // defined(TENSORFLOW_PROTOBUF_USES_CORD)
259
260} // namespace port
261} // namespace tensorflow
262