1 | /** |
2 | * Copyright (c) Glow Contributors. See CONTRIBUTORS file. |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | */ |
16 | |
17 | #include "glow/Support/ZipUtils.h" |
18 | #include "glow/Support/Memory.h" |
19 | |
20 | #include "llvm/ADT/STLExtras.h" |
21 | |
22 | #include <sstream> |
23 | |
24 | namespace glow { |
25 | |
26 | namespace { |
27 | constexpr int = 30; |
28 | constexpr uint64_t kFieldAlignment = 64; |
29 | |
30 | static std::string getPadding(size_t cursor, const std::string &filename, |
31 | size_t size) { |
32 | size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename.size() + |
33 | sizeof(mz_uint16) * 2; |
34 | if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) { |
35 | start += sizeof(mz_uint16) * 2; |
36 | if (size >= MZ_UINT32_MAX) { |
37 | start += 2 * sizeof(mz_uint64); |
38 | } |
39 | if (cursor >= MZ_UINT32_MAX) { |
40 | start += sizeof(mz_uint64); |
41 | } |
42 | } |
43 | size_t mod = start % kFieldAlignment; |
44 | size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod); |
45 | size_t padding_size = next_offset - start; |
46 | std::string buf(padding_size + 4, 'Z'); |
47 | // zip extra encoding (key, size_of_extra_bytes) |
48 | buf[0] = 'F'; |
49 | buf[1] = 'B'; |
50 | buf[2] = (uint8_t)padding_size; |
51 | buf[3] = (uint8_t)(padding_size >> 8); |
52 | return buf; |
53 | } |
54 | } // namespace |
55 | |
56 | size_t istreamReadFunc(void *pOpaque, mz_uint64 file_ofs, void *pBuf, |
57 | size_t n) { |
58 | auto self = static_cast<ZipReader *>(pOpaque); |
59 | return self->read(file_ofs, static_cast<char *>(pBuf), n); |
60 | } |
61 | |
62 | ZipReader::ZipReader(const std::string &file_name) |
63 | : ar_(glow::make_unique<mz_zip_archive>()), |
64 | in_(glow::make_unique<FileAdapter>(file_name)) { |
65 | init(); |
66 | } |
67 | |
68 | ZipReader::~ZipReader() { |
69 | mz_zip_reader_end(ar_.get()); |
70 | valid("closing reader for archive " , archive_name_.c_str()); |
71 | } |
72 | |
73 | void ZipReader::init() { |
74 | assert(in_ != nullptr); |
75 | assert(ar_ != nullptr); |
76 | memset(ar_.get(), 0, sizeof(mz_zip_archive)); |
77 | size_t size = in_->size(); |
78 | ar_->m_pIO_opaque = this; |
79 | ar_->m_pRead = istreamReadFunc; |
80 | mz_zip_reader_init(ar_.get(), size, 0); |
81 | valid("reading zip archive" ); |
82 | // figure out the archive_name (i.e. the zip folder all the other files are |
83 | // in) all lookups to getRecord will be prefixed by this folder |
84 | int n = mz_zip_reader_get_num_files(ar_.get()); |
85 | if (n == 0) { |
86 | LOG(FATAL) << "archive does not contain any files" ; |
87 | } |
88 | size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0); |
89 | valid("getting filename" ); |
90 | std::string buf(name_size, '\0'); |
91 | mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size); |
92 | valid("getting filename" ); |
93 | auto pos = buf.find_first_of('/'); |
94 | if (pos == std::string::npos) { |
95 | LOG(FATAL) << "file in archive is not in a subdirectory" ; |
96 | } |
97 | archive_name_ = buf.substr(0, pos); |
98 | } |
99 | |
100 | size_t ZipReader::getRecordID(const std::string &name) { |
101 | std::stringstream ss; |
102 | ss << archive_name_ << "/" << name; |
103 | size_t result = |
104 | mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0); |
105 | if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) { |
106 | LOG(FATAL) << "file not found: " << ss.str(); |
107 | } |
108 | valid("locating file " , name.c_str()); |
109 | return result; |
110 | } |
111 | |
112 | bool ZipReader::hasRecord(const std::string &name) { |
113 | std::stringstream ss; |
114 | ss << archive_name_ << "/" << name; |
115 | mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0); |
116 | bool result = ar_->m_last_error != MZ_ZIP_FILE_NOT_FOUND; |
117 | if (!result) { |
118 | ar_->m_last_error = MZ_ZIP_NO_ERROR; |
119 | } |
120 | valid("attempting to locate file " , name.c_str()); |
121 | return result; |
122 | } |
123 | |
124 | std::string ZipReader::getRecord(const std::string &name) { |
125 | size_t key = getRecordID(name); |
126 | mz_zip_archive_file_stat stat; |
127 | mz_zip_reader_file_stat(ar_.get(), key, &stat); |
128 | valid("retrieving file meta-data for " , name.c_str()); |
129 | std::string data; |
130 | data.resize(stat.m_uncomp_size); |
131 | mz_zip_reader_extract_to_mem(ar_.get(), key, &data[0], stat.m_uncomp_size, 0); |
132 | valid("reading file " , name.c_str()); |
133 | return data; |
134 | } |
135 | |
136 | void ZipReader::valid(const char *what, const char *info) { |
137 | auto err = mz_zip_get_last_error(ar_.get()); |
138 | if (err != MZ_ZIP_NO_ERROR) { |
139 | LOG(FATAL) << "PytorchStreamReader failed " << what << info << ": " |
140 | << mz_zip_get_error_string(err); |
141 | } |
142 | } |
143 | |
144 | size_t ostreamWriteFunc(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, |
145 | size_t n) { |
146 | auto *self = static_cast<ZipWriter *>(pOpaque); |
147 | if (self->current_pos_ != file_ofs) { |
148 | // xxx - windows ostringstream refuses to seek to the end of an empty string |
149 | // so we workaround this by not calling seek unless necessary |
150 | // in the case of the first write (to the empty string) file_ofs and |
151 | // current_pos_ will be 0 and the seek won't occur. |
152 | self->out_->seekp(file_ofs); |
153 | if (!*self->out_) { |
154 | return 0; |
155 | } |
156 | } |
157 | self->out_->write(static_cast<const char *>(pBuf), n); |
158 | if (!*self->out_) { |
159 | return 0; |
160 | } |
161 | self->current_pos_ = file_ofs + n; |
162 | return n; |
163 | } |
164 | |
165 | ZipWriter::ZipWriter(std::ostream *out, const std::string &archive_name) |
166 | : out_(out), finalized_{false}, ar_(glow::make_unique<mz_zip_archive>()), |
167 | archive_name_(archive_name) { |
168 | memset(ar_.get(), 0, sizeof(mz_zip_archive)); |
169 | ar_->m_pIO_opaque = this; |
170 | ar_->m_pWrite = ostreamWriteFunc; |
171 | mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64); |
172 | } |
173 | |
174 | void ZipWriter::writeRecord(const std::string &name, const void *data, |
175 | size_t size, bool compress) { |
176 | assert(!finalized_); |
177 | std::stringstream ss; |
178 | ss << archive_name_ << "/" << name; |
179 | const std::string full_name = ss.str(); |
180 | std::string padding = getPadding(ar_->m_archive_size, full_name, size); |
181 | uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0; |
182 | mz_zip_writer_add_mem_ex_v2(ar_.get(), full_name.c_str(), data, size, nullptr, |
183 | 0, flags, 0, 0, nullptr, padding.c_str(), |
184 | padding.size(), nullptr, 0); |
185 | valid("writing file " , name.c_str()); |
186 | } |
187 | |
188 | void ZipWriter::writeEndOfFile() { |
189 | assert(!finalized_); |
190 | finalized_ = true; |
191 | mz_zip_writer_finalize_archive(ar_.get()); |
192 | mz_zip_writer_end(ar_.get()); |
193 | valid("writing central directory for archive " , archive_name_.c_str()); |
194 | } |
195 | |
196 | void ZipWriter::valid(const char *what, const char *info) { |
197 | auto err = mz_zip_get_last_error(ar_.get()); |
198 | if (err != MZ_ZIP_NO_ERROR) { |
199 | LOG(FATAL) << "ZipWriter failed " << what << info << ": " |
200 | << mz_zip_get_error_string(err); |
201 | } |
202 | if (!*out_) { |
203 | LOG(FATAL) << "ZipWriter failed " << what << info << "." ; |
204 | } |
205 | } |
206 | |
207 | ZipWriter::~ZipWriter() { |
208 | if (!finalized_) { |
209 | writeEndOfFile(); |
210 | } |
211 | } |
212 | } // namespace glow |
213 | |