1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 *
 * \file base64.h
 * \brief Data streams that decode from and encode to base64 text.
 *  base64 is easier to store and pass around as plain text, e.g. in MapReduce jobs.
 */
26 | #ifndef TVM_SUPPORT_BASE64_H_ |
27 | #define TVM_SUPPORT_BASE64_H_ |
28 | |
29 | #include <tvm/runtime/logging.h> |
30 | |
31 | #include <cctype> |
32 | #include <cstdio> |
33 | #include <string> |
34 | |
35 | namespace tvm { |
36 | namespace support { |
/*! \brief Lookup tables for base64 encoding and decoding. */
namespace base64 {
/*!
 * \brief Decoding table: maps a character code to its 6-bit base64 value.
 *
 * The array is sized to 256 so that indexing with any unsigned byte value stays
 * in bounds; previously it ended at 'z' (0x7A) and any larger code read past the
 * end of the array. Characters outside the base64 alphabet (including the padding
 * character '=') map to 0, so the table alone cannot tell 'A' from an invalid
 * character.
 */
const char DecodeTable[256] = {
    // 0x00 - 0x1F: control characters
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    // 0x20 - 0x2F: '+' (0x2B) -> 62, '/' (0x2F) -> 63
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 62, 0, 0, 0, 63,
    // 0x30 - 0x3F: '0'-'9' -> 52-61
    52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 0, 0, 0, 0, 0, 0,
    // 0x40 - 0x5F: 'A'-'Z' (0x41-0x5A) -> 0-25
    0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
    15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 0, 0, 0, 0, 0,
    // 0x60 - 0x7A: 'a'-'z' (0x61-0x7A) -> 26-51
    0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
    41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
    // 0x7B - 0xFF: not part of the base64 alphabet; zero-initialized.
};
/*! \brief Encoding table: maps a 6-bit value (0-63) to its base64 character. */
static const char EncodeTable[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
}  // namespace base64
56 | |
57 | /*! |
58 | * \brief Buffer reader from stream to avoid |
59 | * virtual call overhead on each read. |
60 | */ |
61 | class StreamBufferReader { |
62 | public: |
63 | explicit StreamBufferReader(size_t buffer_size) { buffer_.resize(buffer_size); } |
64 | /*! |
65 | * \brief set input stream |
66 | * \param stream The stream to be set |
67 | */ |
68 | void set_stream(dmlc::Stream* stream) { |
69 | stream_ = stream; |
70 | read_len_ = read_ptr_ = 1; |
71 | } |
72 | /*! |
73 | * \return allows quick read using get char |
74 | */ |
75 | int GetChar() { |
76 | while (true) { |
77 | if (read_ptr_ < read_len_) { |
78 | return static_cast<int>(buffer_[read_ptr_++]); |
79 | } else { |
80 | read_len_ = stream_->Read(&buffer_[0], buffer_.length()); |
81 | if (read_len_ == 0) return EOF; |
82 | read_ptr_ = 0; |
83 | } |
84 | } |
85 | } |
86 | /*! \return whether we are reaching the end of file */ |
87 | bool AtEnd() const { return read_len_ == 0; } |
88 | |
89 | private: |
90 | /*! \brief the underlying stream */ |
91 | dmlc::Stream* stream_{nullptr}; |
92 | /*! \brief buffer to hold data */ |
93 | std::string buffer_; |
94 | /*! \brief length of valid data in buffer */ |
95 | size_t read_len_{1}; |
96 | /*! \brief pointer in the buffer */ |
97 | size_t read_ptr_{1}; |
98 | }; |
99 | |
100 | /*! |
101 | * \brief Input stream from base64 encoding |
102 | */ |
103 | class Base64InStream : public dmlc::Stream { |
104 | public: |
105 | explicit Base64InStream(dmlc::Stream* fs) : reader_(256) { reader_.set_stream(fs); } |
106 | /*! |
107 | * \brief initialize the stream position to beginning of next base64 stream |
108 | * \note call this function before actually start read |
109 | */ |
110 | void InitPosition(void) { |
111 | // get a character |
112 | do { |
113 | temp_ch_ = reader_.GetChar(); |
114 | } while (isspace(temp_ch_)); |
115 | } |
116 | /*! \brief whether current position is end of a base64 stream */ |
117 | bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); } |
118 | |
119 | using dmlc::Stream::Read; |
120 | // override read function. |
121 | size_t Read(void* ptr, size_t size) final { |
122 | using base64::DecodeTable; |
123 | if (size == 0) return 0; |
124 | // use tlen to record left size |
125 | size_t tlen = size; |
126 | unsigned char* cptr = static_cast<unsigned char*>(ptr); |
127 | // if anything left, load from previous buffered result |
128 | if (num_prev_ != 0) { |
129 | if (num_prev_ == 2) { |
130 | if (tlen >= 2) { |
131 | *cptr++ = buf_prev[0]; |
132 | *cptr++ = buf_prev[1]; |
133 | tlen -= 2; |
134 | num_prev_ = 0; |
135 | } else { |
136 | // assert tlen == 1 |
137 | *cptr++ = buf_prev[0]; |
138 | --tlen; |
139 | buf_prev[0] = buf_prev[1]; |
140 | num_prev_ = 1; |
141 | } |
142 | } else { |
143 | // assert num_prev_ == 1 |
144 | *cptr++ = buf_prev[0]; |
145 | --tlen; |
146 | num_prev_ = 0; |
147 | } |
148 | } |
149 | if (tlen == 0) return size; |
150 | int nvalue; |
151 | // note: everything goes with 4 bytes in Base64 |
152 | // so we process 4 bytes a unit |
153 | while (tlen && temp_ch_ != EOF && !isspace(temp_ch_)) { |
154 | // first byte |
155 | nvalue = DecodeTable[temp_ch_] << 18; |
156 | { |
157 | // second byte |
158 | temp_ch_ = reader_.GetChar(); |
159 | ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format" ; |
160 | nvalue |= DecodeTable[temp_ch_] << 12; |
161 | *cptr++ = (nvalue >> 16) & 0xFF; |
162 | --tlen; |
163 | } |
164 | { |
165 | // third byte |
166 | temp_ch_ = reader_.GetChar(); |
167 | ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format" ; |
168 | // handle termination |
169 | if (temp_ch_ == '=') { |
170 | temp_ch_ = reader_.GetChar(); |
171 | ICHECK(temp_ch_ == '=') << "invalid base64 format" ; |
172 | temp_ch_ = reader_.GetChar(); |
173 | ICHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format" ; |
174 | break; |
175 | } |
176 | nvalue |= DecodeTable[temp_ch_] << 6; |
177 | if (tlen) { |
178 | *cptr++ = (nvalue >> 8) & 0xFF; |
179 | --tlen; |
180 | } else { |
181 | buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF; |
182 | } |
183 | } |
184 | { |
185 | // fourth byte |
186 | temp_ch_ = reader_.GetChar(); |
187 | ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format" ; |
188 | if (temp_ch_ == '=') { |
189 | temp_ch_ = reader_.GetChar(); |
190 | ICHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format" ; |
191 | break; |
192 | } |
193 | nvalue |= DecodeTable[temp_ch_]; |
194 | if (tlen) { |
195 | *cptr++ = nvalue & 0xFF; |
196 | --tlen; |
197 | } else { |
198 | buf_prev[num_prev_++] = nvalue & 0xFF; |
199 | } |
200 | } |
201 | // get next char |
202 | temp_ch_ = reader_.GetChar(); |
203 | } |
204 | if (kStrictCheck) { |
205 | ICHECK_EQ(tlen, 0) << "Base64InStream: read incomplete" ; |
206 | } |
207 | return size - tlen; |
208 | } |
209 | virtual void Write(const void* ptr, size_t size) { |
210 | LOG(FATAL) << "Base64InStream do not support write" ; |
211 | } |
212 | |
213 | private: |
214 | // internal reader |
215 | StreamBufferReader reader_; |
216 | int temp_ch_{0}; |
217 | int num_prev_{0}; |
218 | unsigned char buf_prev[2]; |
219 | // whether we need to do strict check |
220 | static const bool kStrictCheck = false; |
221 | }; |
222 | |
223 | /*! |
224 | * \brief Stream to write to base64 format. |
225 | */ |
226 | class Base64OutStream : public dmlc::Stream { |
227 | public: |
228 | explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {} |
229 | |
230 | using dmlc::Stream::Write; |
231 | |
232 | void Write(const void* ptr, size_t size) final { |
233 | using base64::EncodeTable; |
234 | size_t tlen = size; |
235 | const unsigned char* cptr = static_cast<const unsigned char*>(ptr); |
236 | while (tlen) { |
237 | while (buf__top_ < 3 && tlen != 0) { |
238 | buf_[++buf__top_] = *cptr++; |
239 | --tlen; |
240 | } |
241 | if (buf__top_ == 3) { |
242 | // flush 4 bytes out |
243 | PutChar(EncodeTable[buf_[1] >> 2]); |
244 | PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]); |
245 | PutChar(EncodeTable[((buf_[2] << 2) | (buf_[3] >> 6)) & 0x3F]); |
246 | PutChar(EncodeTable[buf_[3] & 0x3F]); |
247 | buf__top_ = 0; |
248 | } |
249 | } |
250 | } |
251 | virtual size_t Read(void* ptr, size_t size) { |
252 | LOG(FATAL) << "Base64OutStream do not support read" ; |
253 | } |
254 | /*! |
255 | * \brief finish writing of all current base64 stream, do some post processing |
256 | * \param endch character to put to end of stream, if it is EOF, then nothing will be appended. |
257 | */ |
258 | void Finish(int endch = EOF) { |
259 | using base64::EncodeTable; |
260 | if (buf__top_ == 1) { |
261 | PutChar(EncodeTable[buf_[1] >> 2]); |
262 | PutChar(EncodeTable[(buf_[1] << 4) & 0x3F]); |
263 | PutChar('='); |
264 | PutChar('='); |
265 | } |
266 | if (buf__top_ == 2) { |
267 | PutChar(EncodeTable[buf_[1] >> 2]); |
268 | PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]); |
269 | PutChar(EncodeTable[(buf_[2] << 2) & 0x3F]); |
270 | PutChar('='); |
271 | } |
272 | buf__top_ = 0; |
273 | if (endch != EOF) PutChar(endch); |
274 | this->Flush(); |
275 | } |
276 | |
277 | private: |
278 | static constexpr size_t kBufferSize = 256; |
279 | |
280 | dmlc::Stream* fp_{nullptr}; |
281 | int buf__top_{0}; |
282 | unsigned char buf_[4]; |
283 | std::string out_buf_; |
284 | |
285 | void PutChar(char ch) { |
286 | out_buf_ += ch; |
287 | if (out_buf_.length() >= kBufferSize) Flush(); |
288 | } |
289 | void Flush(void) { |
290 | if (out_buf_.length() != 0) { |
291 | fp_->Write(&out_buf_[0], out_buf_.length()); |
292 | out_buf_.clear(); |
293 | } |
294 | } |
295 | }; |
296 | } // namespace support |
297 | } // namespace tvm |
298 | #endif // TVM_SUPPORT_BASE64_H_ |
299 | |