1/**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "glow/Support/ZipUtils.h"
18#include "glow/Support/Memory.h"
19
20#include "llvm/ADT/STLExtras.h"
21
22#include <sstream>
23
24namespace glow {
25
26namespace {
27constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
28constexpr uint64_t kFieldAlignment = 64;
29
30static std::string getPadding(size_t cursor, const std::string &filename,
31 size_t size) {
32 size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename.size() +
33 sizeof(mz_uint16) * 2;
34 if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
35 start += sizeof(mz_uint16) * 2;
36 if (size >= MZ_UINT32_MAX) {
37 start += 2 * sizeof(mz_uint64);
38 }
39 if (cursor >= MZ_UINT32_MAX) {
40 start += sizeof(mz_uint64);
41 }
42 }
43 size_t mod = start % kFieldAlignment;
44 size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
45 size_t padding_size = next_offset - start;
46 std::string buf(padding_size + 4, 'Z');
47 // zip extra encoding (key, size_of_extra_bytes)
48 buf[0] = 'F';
49 buf[1] = 'B';
50 buf[2] = (uint8_t)padding_size;
51 buf[3] = (uint8_t)(padding_size >> 8);
52 return buf;
53}
54} // namespace
55
56size_t istreamReadFunc(void *pOpaque, mz_uint64 file_ofs, void *pBuf,
57 size_t n) {
58 auto self = static_cast<ZipReader *>(pOpaque);
59 return self->read(file_ofs, static_cast<char *>(pBuf), n);
60}
61
62ZipReader::ZipReader(const std::string &file_name)
63 : ar_(glow::make_unique<mz_zip_archive>()),
64 in_(glow::make_unique<FileAdapter>(file_name)) {
65 init();
66}
67
68ZipReader::~ZipReader() {
69 mz_zip_reader_end(ar_.get());
70 valid("closing reader for archive ", archive_name_.c_str());
71}
72
73void ZipReader::init() {
74 assert(in_ != nullptr);
75 assert(ar_ != nullptr);
76 memset(ar_.get(), 0, sizeof(mz_zip_archive));
77 size_t size = in_->size();
78 ar_->m_pIO_opaque = this;
79 ar_->m_pRead = istreamReadFunc;
80 mz_zip_reader_init(ar_.get(), size, 0);
81 valid("reading zip archive");
82 // figure out the archive_name (i.e. the zip folder all the other files are
83 // in) all lookups to getRecord will be prefixed by this folder
84 int n = mz_zip_reader_get_num_files(ar_.get());
85 if (n == 0) {
86 LOG(FATAL) << "archive does not contain any files";
87 }
88 size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
89 valid("getting filename");
90 std::string buf(name_size, '\0');
91 mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
92 valid("getting filename");
93 auto pos = buf.find_first_of('/');
94 if (pos == std::string::npos) {
95 LOG(FATAL) << "file in archive is not in a subdirectory";
96 }
97 archive_name_ = buf.substr(0, pos);
98}
99
100size_t ZipReader::getRecordID(const std::string &name) {
101 std::stringstream ss;
102 ss << archive_name_ << "/" << name;
103 size_t result =
104 mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
105 if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) {
106 LOG(FATAL) << "file not found: " << ss.str();
107 }
108 valid("locating file ", name.c_str());
109 return result;
110}
111
112bool ZipReader::hasRecord(const std::string &name) {
113 std::stringstream ss;
114 ss << archive_name_ << "/" << name;
115 mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
116 bool result = ar_->m_last_error != MZ_ZIP_FILE_NOT_FOUND;
117 if (!result) {
118 ar_->m_last_error = MZ_ZIP_NO_ERROR;
119 }
120 valid("attempting to locate file ", name.c_str());
121 return result;
122}
123
124std::string ZipReader::getRecord(const std::string &name) {
125 size_t key = getRecordID(name);
126 mz_zip_archive_file_stat stat;
127 mz_zip_reader_file_stat(ar_.get(), key, &stat);
128 valid("retrieving file meta-data for ", name.c_str());
129 std::string data;
130 data.resize(stat.m_uncomp_size);
131 mz_zip_reader_extract_to_mem(ar_.get(), key, &data[0], stat.m_uncomp_size, 0);
132 valid("reading file ", name.c_str());
133 return data;
134}
135
136void ZipReader::valid(const char *what, const char *info) {
137 auto err = mz_zip_get_last_error(ar_.get());
138 if (err != MZ_ZIP_NO_ERROR) {
139 LOG(FATAL) << "PytorchStreamReader failed " << what << info << ": "
140 << mz_zip_get_error_string(err);
141 }
142}
143
144size_t ostreamWriteFunc(void *pOpaque, mz_uint64 file_ofs, const void *pBuf,
145 size_t n) {
146 auto *self = static_cast<ZipWriter *>(pOpaque);
147 if (self->current_pos_ != file_ofs) {
148 // xxx - windows ostringstream refuses to seek to the end of an empty string
149 // so we workaround this by not calling seek unless necessary
150 // in the case of the first write (to the empty string) file_ofs and
151 // current_pos_ will be 0 and the seek won't occur.
152 self->out_->seekp(file_ofs);
153 if (!*self->out_) {
154 return 0;
155 }
156 }
157 self->out_->write(static_cast<const char *>(pBuf), n);
158 if (!*self->out_) {
159 return 0;
160 }
161 self->current_pos_ = file_ofs + n;
162 return n;
163}
164
165ZipWriter::ZipWriter(std::ostream *out, const std::string &archive_name)
166 : out_(out), finalized_{false}, ar_(glow::make_unique<mz_zip_archive>()),
167 archive_name_(archive_name) {
168 memset(ar_.get(), 0, sizeof(mz_zip_archive));
169 ar_->m_pIO_opaque = this;
170 ar_->m_pWrite = ostreamWriteFunc;
171 mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
172}
173
174void ZipWriter::writeRecord(const std::string &name, const void *data,
175 size_t size, bool compress) {
176 assert(!finalized_);
177 std::stringstream ss;
178 ss << archive_name_ << "/" << name;
179 const std::string full_name = ss.str();
180 std::string padding = getPadding(ar_->m_archive_size, full_name, size);
181 uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
182 mz_zip_writer_add_mem_ex_v2(ar_.get(), full_name.c_str(), data, size, nullptr,
183 0, flags, 0, 0, nullptr, padding.c_str(),
184 padding.size(), nullptr, 0);
185 valid("writing file ", name.c_str());
186}
187
188void ZipWriter::writeEndOfFile() {
189 assert(!finalized_);
190 finalized_ = true;
191 mz_zip_writer_finalize_archive(ar_.get());
192 mz_zip_writer_end(ar_.get());
193 valid("writing central directory for archive ", archive_name_.c_str());
194}
195
196void ZipWriter::valid(const char *what, const char *info) {
197 auto err = mz_zip_get_last_error(ar_.get());
198 if (err != MZ_ZIP_NO_ERROR) {
199 LOG(FATAL) << "ZipWriter failed " << what << info << ": "
200 << mz_zip_get_error_string(err);
201 }
202 if (!*out_) {
203 LOG(FATAL) << "ZipWriter failed " << what << info << ".";
204 }
205}
206
207ZipWriter::~ZipWriter() {
208 if (!finalized_) {
209 writeEndOfFile();
210 }
211}
212} // namespace glow
213