1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/tsl/platform/base64.h"

#include <cstring>
#include <memory>
#include <new>

#include "tensorflow/tsl/platform/errors.h"
#include "tensorflow/tsl/platform/stringpiece.h"
23
24namespace tsl {
25namespace {
// Reverse lookup table used for decoding. Index is a 7-bit ASCII code; the
// value is that character's 6-bit value in the URL-safe base64 alphabet
// ('A'-'Z' -> 0x00-0x19, 'a'-'z' -> 0x1A-0x33, '0'-'9' -> 0x34-0x3D,
// '-' -> 0x3E, '_' -> 0x3F), or -1 if the character is not in that alphabet.
// Note that the standard-alphabet characters '+' and '/', as well as the
// padding character '=', all map to -1 here.
// This array must have signed type: Convert() depends on the -1 entries
// sign-extending to a 32-bit value with its upper bits set.
// clang-format off
constexpr int8 kBase64Bytes[128] = {
    -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
    -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
    -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,
    -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1,   -1, 0x3E,   -1,   -1,
  0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D,   -1,   -1,
    -1,   -1,   -1,   -1,   -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
  0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12,
  0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19,   -1,   -1,   -1,   -1, 0x3F,
    -1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24,
  0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
  0x31, 0x32, 0x33,   -1,   -1,   -1,   -1,   -1};
// clang-format on
41
// Forward table used for encoding: entry i is the character for the 6-bit
// value i in the URL-safe base64 alphabet ('-' and '_' replace '+' and '/').
constexpr char kBase64UrlSafeChars[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
// 64 symbols plus the terminating NUL of the string literal.
static_assert(sizeof(kBase64UrlSafeChars) == 65,
              "base64 alphabet must contain exactly 64 symbols");

// Character used to pad the final group of an encoding to 4 characters.
constexpr char kPadChar{'='};
46
47// Converts a char (8 bits) into a 6-bit value for decoding. If the input char
48// is invalid for base64 encoding, the return value has at least its upper 25
49// bits set.
50inline uint32 Convert(char x) {
51 // If x < 128, then we look up x in the table. If x is valid, then the table
52 // will have a value <= 0x3F, otherwise the table will have -1. If x >= 128,
53 // we still do some table lookup, but the value is ignored since we explicitly
54 // set the high bit of y to 1. Either way, y is negative (high bit set) in
55 // case of error.
56 const int8_t y = kBase64Bytes[x & 0x7F] | (x & 0x80);
57 // Casting from int8 to int32 preserves sign by sign extension. If y was
58 // negative, at least its 25 high bits of the return value are set.
59 const int32_t z = static_cast<int32>(y);
60 return static_cast<uint32>(z);
61}
62
63Status DecodeThreeChars(const char* codes, char* result) {
64 const uint32 packed = (Convert(codes[0]) << 18) | (Convert(codes[1]) << 12) |
65 (Convert(codes[2]) << 6) | (Convert(codes[3]));
66 // Convert() return value has upper 25 bits set if input is invalid.
67 // Therefore `packed` has high bits set iff at least one of code is invalid.
68 if (TF_PREDICT_FALSE((packed & 0xFF000000) != 0)) {
69 return errors::InvalidArgument("Invalid character found in base64.");
70 }
71 result[0] = static_cast<char>(packed >> 16);
72 result[1] = static_cast<char>(packed >> 8);
73 result[2] = static_cast<char>(packed);
74 return OkStatus();
75}
76} // namespace
77
78template <typename T>
79Status Base64Decode(StringPiece data, T* decoded) {
80 if (decoded == nullptr) {
81 return errors::Internal("'decoded' cannot be nullptr.");
82 }
83
84 if (data.empty()) {
85 decoded->clear();
86 return OkStatus();
87 }
88
89 // This decoding procedure will write 3 * ceil(data.size() / 4) bytes to be
90 // output buffer, then truncate if necessary. Therefore we must overestimate
91 // and allocate sufficient amount. Currently max_decoded_size may overestimate
92 // by up to 3 bytes.
93 const size_t max_decoded_size = 3 * (data.size() / 4) + 3;
94 std::unique_ptr<char[]> buffer(new char[max_decoded_size]);
95 char* current = buffer.get();
96 if (current == nullptr) {
97 return errors::ResourceExhausted(
98 "Failed to allocate buffer for decoded string.");
99 }
100
101 const char* b64 = data.data();
102 const char* end = data.data() + data.size();
103
104 while (end - b64 > 4) {
105 TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current));
106 b64 += 4;
107 current += 3;
108 }
109
110 if (end - b64 == 4) {
111 // The data length is a multiple of 4. Check for padding.
112 // Base64 cannot have more than 2 paddings.
113 if (b64[2] == kPadChar && b64[3] == kPadChar) {
114 end -= 2;
115 }
116 if (b64[2] != kPadChar && b64[3] == kPadChar) {
117 end -= 1;
118 }
119 }
120
121 const int remain = static_cast<int>(end - b64);
122 if (TF_PREDICT_FALSE(remain == 1)) {
123 // We may check this condition early by checking data.size() % 4 == 1.
124 return errors::InvalidArgument(
125 "Base64 string length cannot be 1 modulo 4.");
126 }
127
128 // A valid base64 character will replace paddings, if any.
129 char tail[4] = {kBase64UrlSafeChars[0], kBase64UrlSafeChars[0],
130 kBase64UrlSafeChars[0], kBase64UrlSafeChars[0]};
131 // Copy tail of the input into the array, then decode.
132 std::memcpy(tail, b64, remain * sizeof(*b64));
133 TF_RETURN_IF_ERROR(DecodeThreeChars(tail, current));
134 // We know how many parsed characters are valid.
135 current += remain - 1;
136
137 decoded->assign(buffer.get(), current - buffer.get());
138 return OkStatus();
139}
140
141template <typename T>
142Status Base64Encode(StringPiece source, T* encoded) {
143 return Base64Encode(source, false, encoded);
144}
145
146template <typename T>
147Status Base64Encode(StringPiece source, bool with_padding, T* encoded) {
148 const char* const base64_chars = kBase64UrlSafeChars;
149 if (encoded == nullptr) {
150 return errors::Internal("'encoded' cannot be nullptr.");
151 }
152
153 // max_encoded_size may overestimate by up to 4 bytes.
154 const size_t max_encoded_size = 4 * (source.size() / 3) + 4;
155 std::unique_ptr<char[]> buffer(new char[max_encoded_size]);
156 char* current = buffer.get();
157 if (current == nullptr) {
158 return errors::ResourceExhausted(
159 "Failed to allocate buffer for encoded string.");
160 }
161
162 const char* data = source.data();
163 const char* const end = source.data() + source.size();
164
165 // Encode each block.
166 while (end - data >= 3) {
167 *current++ = base64_chars[(data[0] >> 2) & 0x3F];
168 *current++ =
169 base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
170 *current++ =
171 base64_chars[((data[1] & 0x0F) << 2) | ((data[2] >> 6) & 0x03)];
172 *current++ = base64_chars[data[2] & 0x3F];
173
174 data += 3;
175 }
176
177 // Take care of the tail.
178 if (end - data == 2) {
179 *current++ = base64_chars[(data[0] >> 2) & 0x3F];
180 *current++ =
181 base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
182 *current++ = base64_chars[(data[1] & 0x0F) << 2];
183 if (with_padding) {
184 *current++ = kPadChar;
185 }
186 } else if (end - data == 1) {
187 *current++ = base64_chars[(data[0] >> 2) & 0x3F];
188 *current++ = base64_chars[(data[0] & 0x03) << 4];
189 if (with_padding) {
190 *current++ = kPadChar;
191 *current++ = kPadChar;
192 }
193 }
194
195 encoded->assign(buffer.get(), current - buffer.get());
196 return OkStatus();
197}
198
199template Status Base64Decode<std::string>(StringPiece data,
200 std::string* decoded);
201template Status Base64Encode<std::string>(StringPiece source,
202 std::string* encoded);
203template Status Base64Encode<std::string>(StringPiece source, bool with_padding,
204 std::string* encoded);
205
206template Status Base64Decode<tstring>(StringPiece data, tstring* decoded);
207template Status Base64Encode<tstring>(StringPiece source, tstring* encoded);
208template Status Base64Encode<tstring>(StringPiece source, bool with_padding,
209 tstring* encoded);
210
211} // namespace tsl
212