1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file base64.h
23 * \brief data stream support to input and output from/to base64 stream
24 * base64 is easier to store and pass as text format in mapreduce
25 */
26#ifndef TVM_SUPPORT_BASE64_H_
27#define TVM_SUPPORT_BASE64_H_
28
29#include <tvm/runtime/logging.h>
30
31#include <cctype>
32#include <cstdio>
33#include <string>
34
35namespace tvm {
36namespace support {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
/*!
 * \brief Decoding table: maps a character code to its 6-bit base64 value.
 *
 * The table has one entry for every unsigned char value, so indexing it with
 * any byte read from a stream stays in bounds. Characters that are not part of
 * the base64 alphabet (including the padding character '=') map to 0; callers
 * must recognize padding themselves. Entries 123..255 are zero-filled by the
 * explicit array bound.
 */
const char DecodeTable[256] = {
    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  // 0-9
    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  // 10-19
    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  // 20-29
    0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  // 30-39
    0,  0,  0,                              // 40-42
    62,                                     // '+' (43)
    0,  0,  0,                              // 44-46
    63,                                     // '/' (47)
    52, 53, 54, 55, 56, 57, 58, 59, 60, 61,  // '0'-'9' (48-57)
    0,  0,  0,  0,  0,  0,  0,               // 58-64
    0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,  // 'A'-'M' (65-77)
    13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,  // 'N'-'Z' (78-90)
    0,  0,  0,  0,  0,  0,                               // 91-96
    26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,  // 'a'-'m' (97-109)
    39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,  // 'n'-'z' (110-122)
};
/*! \brief Encoding table: maps a 6-bit value (0-63) to its base64 character. */
static const char EncodeTable[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
}  // namespace base64
56
57/*!
58 * \brief Buffer reader from stream to avoid
59 * virtual call overhead on each read.
60 */
61class StreamBufferReader {
62 public:
63 explicit StreamBufferReader(size_t buffer_size) { buffer_.resize(buffer_size); }
64 /*!
65 * \brief set input stream
66 * \param stream The stream to be set
67 */
68 void set_stream(dmlc::Stream* stream) {
69 stream_ = stream;
70 read_len_ = read_ptr_ = 1;
71 }
72 /*!
73 * \return allows quick read using get char
74 */
75 int GetChar() {
76 while (true) {
77 if (read_ptr_ < read_len_) {
78 return static_cast<int>(buffer_[read_ptr_++]);
79 } else {
80 read_len_ = stream_->Read(&buffer_[0], buffer_.length());
81 if (read_len_ == 0) return EOF;
82 read_ptr_ = 0;
83 }
84 }
85 }
86 /*! \return whether we are reaching the end of file */
87 bool AtEnd() const { return read_len_ == 0; }
88
89 private:
90 /*! \brief the underlying stream */
91 dmlc::Stream* stream_{nullptr};
92 /*! \brief buffer to hold data */
93 std::string buffer_;
94 /*! \brief length of valid data in buffer */
95 size_t read_len_{1};
96 /*! \brief pointer in the buffer */
97 size_t read_ptr_{1};
98};
99
100/*!
101 * \brief Input stream from base64 encoding
102 */
103class Base64InStream : public dmlc::Stream {
104 public:
105 explicit Base64InStream(dmlc::Stream* fs) : reader_(256) { reader_.set_stream(fs); }
106 /*!
107 * \brief initialize the stream position to beginning of next base64 stream
108 * \note call this function before actually start read
109 */
110 void InitPosition(void) {
111 // get a character
112 do {
113 temp_ch_ = reader_.GetChar();
114 } while (isspace(temp_ch_));
115 }
116 /*! \brief whether current position is end of a base64 stream */
117 bool IsEOF(void) const { return num_prev_ == 0 && (temp_ch_ == EOF || isspace(temp_ch_)); }
118
119 using dmlc::Stream::Read;
120 // override read function.
121 size_t Read(void* ptr, size_t size) final {
122 using base64::DecodeTable;
123 if (size == 0) return 0;
124 // use tlen to record left size
125 size_t tlen = size;
126 unsigned char* cptr = static_cast<unsigned char*>(ptr);
127 // if anything left, load from previous buffered result
128 if (num_prev_ != 0) {
129 if (num_prev_ == 2) {
130 if (tlen >= 2) {
131 *cptr++ = buf_prev[0];
132 *cptr++ = buf_prev[1];
133 tlen -= 2;
134 num_prev_ = 0;
135 } else {
136 // assert tlen == 1
137 *cptr++ = buf_prev[0];
138 --tlen;
139 buf_prev[0] = buf_prev[1];
140 num_prev_ = 1;
141 }
142 } else {
143 // assert num_prev_ == 1
144 *cptr++ = buf_prev[0];
145 --tlen;
146 num_prev_ = 0;
147 }
148 }
149 if (tlen == 0) return size;
150 int nvalue;
151 // note: everything goes with 4 bytes in Base64
152 // so we process 4 bytes a unit
153 while (tlen && temp_ch_ != EOF && !isspace(temp_ch_)) {
154 // first byte
155 nvalue = DecodeTable[temp_ch_] << 18;
156 {
157 // second byte
158 temp_ch_ = reader_.GetChar();
159 ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
160 nvalue |= DecodeTable[temp_ch_] << 12;
161 *cptr++ = (nvalue >> 16) & 0xFF;
162 --tlen;
163 }
164 {
165 // third byte
166 temp_ch_ = reader_.GetChar();
167 ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
168 // handle termination
169 if (temp_ch_ == '=') {
170 temp_ch_ = reader_.GetChar();
171 ICHECK(temp_ch_ == '=') << "invalid base64 format";
172 temp_ch_ = reader_.GetChar();
173 ICHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format";
174 break;
175 }
176 nvalue |= DecodeTable[temp_ch_] << 6;
177 if (tlen) {
178 *cptr++ = (nvalue >> 8) & 0xFF;
179 --tlen;
180 } else {
181 buf_prev[num_prev_++] = (nvalue >> 8) & 0xFF;
182 }
183 }
184 {
185 // fourth byte
186 temp_ch_ = reader_.GetChar();
187 ICHECK(temp_ch_ != EOF && !isspace(temp_ch_)) << "invalid base64 format";
188 if (temp_ch_ == '=') {
189 temp_ch_ = reader_.GetChar();
190 ICHECK(temp_ch_ == EOF || isspace(temp_ch_)) << "invalid base64 format";
191 break;
192 }
193 nvalue |= DecodeTable[temp_ch_];
194 if (tlen) {
195 *cptr++ = nvalue & 0xFF;
196 --tlen;
197 } else {
198 buf_prev[num_prev_++] = nvalue & 0xFF;
199 }
200 }
201 // get next char
202 temp_ch_ = reader_.GetChar();
203 }
204 if (kStrictCheck) {
205 ICHECK_EQ(tlen, 0) << "Base64InStream: read incomplete";
206 }
207 return size - tlen;
208 }
209 virtual void Write(const void* ptr, size_t size) {
210 LOG(FATAL) << "Base64InStream do not support write";
211 }
212
213 private:
214 // internal reader
215 StreamBufferReader reader_;
216 int temp_ch_{0};
217 int num_prev_{0};
218 unsigned char buf_prev[2];
219 // whether we need to do strict check
220 static const bool kStrictCheck = false;
221};
222
223/*!
224 * \brief Stream to write to base64 format.
225 */
226class Base64OutStream : public dmlc::Stream {
227 public:
228 explicit Base64OutStream(dmlc::Stream* fp) : fp_(fp) {}
229
230 using dmlc::Stream::Write;
231
232 void Write(const void* ptr, size_t size) final {
233 using base64::EncodeTable;
234 size_t tlen = size;
235 const unsigned char* cptr = static_cast<const unsigned char*>(ptr);
236 while (tlen) {
237 while (buf__top_ < 3 && tlen != 0) {
238 buf_[++buf__top_] = *cptr++;
239 --tlen;
240 }
241 if (buf__top_ == 3) {
242 // flush 4 bytes out
243 PutChar(EncodeTable[buf_[1] >> 2]);
244 PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]);
245 PutChar(EncodeTable[((buf_[2] << 2) | (buf_[3] >> 6)) & 0x3F]);
246 PutChar(EncodeTable[buf_[3] & 0x3F]);
247 buf__top_ = 0;
248 }
249 }
250 }
251 virtual size_t Read(void* ptr, size_t size) {
252 LOG(FATAL) << "Base64OutStream do not support read";
253 }
254 /*!
255 * \brief finish writing of all current base64 stream, do some post processing
256 * \param endch character to put to end of stream, if it is EOF, then nothing will be appended.
257 */
258 void Finish(int endch = EOF) {
259 using base64::EncodeTable;
260 if (buf__top_ == 1) {
261 PutChar(EncodeTable[buf_[1] >> 2]);
262 PutChar(EncodeTable[(buf_[1] << 4) & 0x3F]);
263 PutChar('=');
264 PutChar('=');
265 }
266 if (buf__top_ == 2) {
267 PutChar(EncodeTable[buf_[1] >> 2]);
268 PutChar(EncodeTable[((buf_[1] << 4) | (buf_[2] >> 4)) & 0x3F]);
269 PutChar(EncodeTable[(buf_[2] << 2) & 0x3F]);
270 PutChar('=');
271 }
272 buf__top_ = 0;
273 if (endch != EOF) PutChar(endch);
274 this->Flush();
275 }
276
277 private:
278 static constexpr size_t kBufferSize = 256;
279
280 dmlc::Stream* fp_{nullptr};
281 int buf__top_{0};
282 unsigned char buf_[4];
283 std::string out_buf_;
284
285 void PutChar(char ch) {
286 out_buf_ += ch;
287 if (out_buf_.length() >= kBufferSize) Flush();
288 }
289 void Flush(void) {
290 if (out_buf_.length() != 0) {
291 fp_->Write(&out_buf_[0], out_buf_.length());
292 out_buf_.clear();
293 }
294 }
295};
296} // namespace support
297} // namespace tvm
298#endif // TVM_SUPPORT_BASE64_H_
299