1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/tsl/platform/base64.h"

#include <cstring>
#include <memory>
#include <new>

#include "tensorflow/tsl/platform/errors.h"
#include "tensorflow/tsl/platform/stringpiece.h"
23 | |
24 | namespace tsl { |
25 | namespace { |
// Reverse lookup table used when decoding: entry i is the 6-bit value of the
// 7-bit ASCII character i in the URL-safe alphabet (kBase64UrlSafeChars
// below), or -1 if character i is not part of that alphabet.
// Notable entries: '-' (0x2D) -> 0x3E and '_' (0x5F) -> 0x3F, while the
// standard-alphabet characters '+' and '/' and the pad character '=' all map
// to -1 and are therefore rejected wherever they are decoded.
// This array must have signed type: Convert() widens an entry to 32 bits and
// relies on -1 sign-extending into a value with its upper bits set.
// clang-format off
constexpr int8 kBase64Bytes[128] = {
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
    -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x3E, -1, -1,
    0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, -1, -1,
    -1, -1, -1, -1, -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
    0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12,
    0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, -1, -1, -1, -1, 0x3F,
    -1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24,
    0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
    0x31, 0x32, 0x33, -1, -1, -1, -1, -1};
// clang-format on

// Encoding alphabet: the character at index i encodes the 6-bit value i.
// This is the URL-safe variant ('-' and '_' for values 62 and 63). The array
// holds the 64 alphabet characters plus the terminating NUL.
constexpr char kBase64UrlSafeChars[65] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

// Padding character: emitted by Base64Encode when with_padding is true, and
// stripped by Base64Decode from the final four-character group.
constexpr char kPadChar = '=';
46 | |
47 | // Converts a char (8 bits) into a 6-bit value for decoding. If the input char |
48 | // is invalid for base64 encoding, the return value has at least its upper 25 |
49 | // bits set. |
50 | inline uint32 Convert(char x) { |
51 | // If x < 128, then we look up x in the table. If x is valid, then the table |
52 | // will have a value <= 0x3F, otherwise the table will have -1. If x >= 128, |
53 | // we still do some table lookup, but the value is ignored since we explicitly |
54 | // set the high bit of y to 1. Either way, y is negative (high bit set) in |
55 | // case of error. |
56 | const int8_t y = kBase64Bytes[x & 0x7F] | (x & 0x80); |
57 | // Casting from int8 to int32 preserves sign by sign extension. If y was |
58 | // negative, at least its 25 high bits of the return value are set. |
59 | const int32_t z = static_cast<int32>(y); |
60 | return static_cast<uint32>(z); |
61 | } |
62 | |
63 | Status DecodeThreeChars(const char* codes, char* result) { |
64 | const uint32 packed = (Convert(codes[0]) << 18) | (Convert(codes[1]) << 12) | |
65 | (Convert(codes[2]) << 6) | (Convert(codes[3])); |
66 | // Convert() return value has upper 25 bits set if input is invalid. |
67 | // Therefore `packed` has high bits set iff at least one of code is invalid. |
68 | if (TF_PREDICT_FALSE((packed & 0xFF000000) != 0)) { |
69 | return errors::InvalidArgument("Invalid character found in base64." ); |
70 | } |
71 | result[0] = static_cast<char>(packed >> 16); |
72 | result[1] = static_cast<char>(packed >> 8); |
73 | result[2] = static_cast<char>(packed); |
74 | return OkStatus(); |
75 | } |
76 | } // namespace |
77 | |
78 | template <typename T> |
79 | Status Base64Decode(StringPiece data, T* decoded) { |
80 | if (decoded == nullptr) { |
81 | return errors::Internal("'decoded' cannot be nullptr." ); |
82 | } |
83 | |
84 | if (data.empty()) { |
85 | decoded->clear(); |
86 | return OkStatus(); |
87 | } |
88 | |
89 | // This decoding procedure will write 3 * ceil(data.size() / 4) bytes to be |
90 | // output buffer, then truncate if necessary. Therefore we must overestimate |
91 | // and allocate sufficient amount. Currently max_decoded_size may overestimate |
92 | // by up to 3 bytes. |
93 | const size_t max_decoded_size = 3 * (data.size() / 4) + 3; |
94 | std::unique_ptr<char[]> buffer(new char[max_decoded_size]); |
95 | char* current = buffer.get(); |
96 | if (current == nullptr) { |
97 | return errors::ResourceExhausted( |
98 | "Failed to allocate buffer for decoded string." ); |
99 | } |
100 | |
101 | const char* b64 = data.data(); |
102 | const char* end = data.data() + data.size(); |
103 | |
104 | while (end - b64 > 4) { |
105 | TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current)); |
106 | b64 += 4; |
107 | current += 3; |
108 | } |
109 | |
110 | if (end - b64 == 4) { |
111 | // The data length is a multiple of 4. Check for padding. |
112 | // Base64 cannot have more than 2 paddings. |
113 | if (b64[2] == kPadChar && b64[3] == kPadChar) { |
114 | end -= 2; |
115 | } |
116 | if (b64[2] != kPadChar && b64[3] == kPadChar) { |
117 | end -= 1; |
118 | } |
119 | } |
120 | |
121 | const int remain = static_cast<int>(end - b64); |
122 | if (TF_PREDICT_FALSE(remain == 1)) { |
123 | // We may check this condition early by checking data.size() % 4 == 1. |
124 | return errors::InvalidArgument( |
125 | "Base64 string length cannot be 1 modulo 4." ); |
126 | } |
127 | |
128 | // A valid base64 character will replace paddings, if any. |
129 | char tail[4] = {kBase64UrlSafeChars[0], kBase64UrlSafeChars[0], |
130 | kBase64UrlSafeChars[0], kBase64UrlSafeChars[0]}; |
131 | // Copy tail of the input into the array, then decode. |
132 | std::memcpy(tail, b64, remain * sizeof(*b64)); |
133 | TF_RETURN_IF_ERROR(DecodeThreeChars(tail, current)); |
134 | // We know how many parsed characters are valid. |
135 | current += remain - 1; |
136 | |
137 | decoded->assign(buffer.get(), current - buffer.get()); |
138 | return OkStatus(); |
139 | } |
140 | |
141 | template <typename T> |
142 | Status Base64Encode(StringPiece source, T* encoded) { |
143 | return Base64Encode(source, false, encoded); |
144 | } |
145 | |
146 | template <typename T> |
147 | Status Base64Encode(StringPiece source, bool with_padding, T* encoded) { |
148 | const char* const base64_chars = kBase64UrlSafeChars; |
149 | if (encoded == nullptr) { |
150 | return errors::Internal("'encoded' cannot be nullptr." ); |
151 | } |
152 | |
153 | // max_encoded_size may overestimate by up to 4 bytes. |
154 | const size_t max_encoded_size = 4 * (source.size() / 3) + 4; |
155 | std::unique_ptr<char[]> buffer(new char[max_encoded_size]); |
156 | char* current = buffer.get(); |
157 | if (current == nullptr) { |
158 | return errors::ResourceExhausted( |
159 | "Failed to allocate buffer for encoded string." ); |
160 | } |
161 | |
162 | const char* data = source.data(); |
163 | const char* const end = source.data() + source.size(); |
164 | |
165 | // Encode each block. |
166 | while (end - data >= 3) { |
167 | *current++ = base64_chars[(data[0] >> 2) & 0x3F]; |
168 | *current++ = |
169 | base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)]; |
170 | *current++ = |
171 | base64_chars[((data[1] & 0x0F) << 2) | ((data[2] >> 6) & 0x03)]; |
172 | *current++ = base64_chars[data[2] & 0x3F]; |
173 | |
174 | data += 3; |
175 | } |
176 | |
177 | // Take care of the tail. |
178 | if (end - data == 2) { |
179 | *current++ = base64_chars[(data[0] >> 2) & 0x3F]; |
180 | *current++ = |
181 | base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)]; |
182 | *current++ = base64_chars[(data[1] & 0x0F) << 2]; |
183 | if (with_padding) { |
184 | *current++ = kPadChar; |
185 | } |
186 | } else if (end - data == 1) { |
187 | *current++ = base64_chars[(data[0] >> 2) & 0x3F]; |
188 | *current++ = base64_chars[(data[0] & 0x03) << 4]; |
189 | if (with_padding) { |
190 | *current++ = kPadChar; |
191 | *current++ = kPadChar; |
192 | } |
193 | } |
194 | |
195 | encoded->assign(buffer.get(), current - buffer.get()); |
196 | return OkStatus(); |
197 | } |
198 | |
199 | template Status Base64Decode<std::string>(StringPiece data, |
200 | std::string* decoded); |
201 | template Status Base64Encode<std::string>(StringPiece source, |
202 | std::string* encoded); |
203 | template Status Base64Encode<std::string>(StringPiece source, bool with_padding, |
204 | std::string* encoded); |
205 | |
206 | template Status Base64Decode<tstring>(StringPiece data, tstring* decoded); |
207 | template Status Base64Encode<tstring>(StringPiece source, tstring* encoded); |
208 | template Status Base64Encode<tstring>(StringPiece source, bool with_padding, |
209 | tstring* encoded); |
210 | |
211 | } // namespace tsl |
212 | |