#include <cstdio>
#include <cstring>
#include <cerrno>
#include <istream>
#include <ostream>
#include <fstream>
#include <algorithm>
#include <sys/stat.h>
#include <sys/types.h>


#include <c10/core/Allocator.h>
#include <c10/core/CPUAllocator.h>
#include <c10/core/Backend.h>
#include <c10/util/Exception.h>

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"

#include "miniz.h"

namespace caffe2 {
namespace serialize {
constexpr c10::string_view kDebugPklSuffix(".debug_pkl");

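// miniz read callback: forwards reads to the owning PyTorchStreamReader.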
size_t istream_read_func(void* pOpaque, mz_uint64 file_ofs, void* pBuf, size_t n) {
  auto self = static_cast<PyTorchStreamReader*>(pOpaque);
  return self->read(file_ofs, static_cast<char*>(pBuf), n);
}

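// Returns the last path component with its extension stripped,
// e.g. "dir/archive.zip" -> "archive". Both '/' and '\\' are treated
// as path separators.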
static std::string basename(const std::string& name) {
  size_t start = 0;
  for (size_t i = 0; i < name.size(); ++i) {
    if (name[i] == '\\' || name[i] == '/') {
      start = i + 1;
    }
  }

  if (start >= name.size())
    return "";

  size_t end = name.size();
  for (size_t i = end; i > start; --i) {
    if (name[i - 1] == '.') {
      end = i - 1;
      break;
    }
  }
  return name.substr(start, end - start);
}

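// Returns the directory portion of a path, or "" if the path contains
// no separator.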
static std::string parentdir(const std::string& name) {
  size_t end = name.find_last_of('/');
  if (end == std::string::npos)
    end = name.find_last_of('\\');

  if (end == std::string::npos)
    return "";

  return name.substr(0, end);
}

size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
  return in_->read(pos, buf, n, "reading file");
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
    : ar_(std::make_unique<mz_zip_archive>()),
      in_(std::make_unique<FileAdapter>(file_name)) {
  init();
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
    : ar_(std::make_unique<mz_zip_archive>()),
      in_(std::make_unique<IStreamAdapter>(in)) {
  init();
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(
    std::shared_ptr<ReadAdapterInterface> in)
    : ar_(std::make_unique<mz_zip_archive>()), in_(std::move(in)) {
  init();
}

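// Shared initialization for all constructors: rejects archives from the
// pre-release "PYTORCH1" format, opens the zip via miniz, records the
// top-level folder name that prefixes every record, and reads and
// validates the serialization format version.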
void PyTorchStreamReader::init() {
  AT_ASSERT(in_ != nullptr);
  AT_ASSERT(ar_ != nullptr);
  memset(ar_.get(), 0, sizeof(mz_zip_archive));

  size_t size = in_->size();

  // check for the old magic number
  constexpr size_t kMagicValueLength = 8;
  if (size > kMagicValueLength) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    char buf[kMagicValueLength];
    read(0, buf, kMagicValueLength);
    valid("checking magic number");
    AT_ASSERTM(
        memcmp("PYTORCH1", buf, kMagicValueLength) != 0,
        "File is an unsupported archive format from the preview release.");
  }

  ar_->m_pIO_opaque = this;
  ar_->m_pRead = istream_read_func;

  mz_zip_reader_init(ar_.get(), size, 0);
  valid("reading zip archive");

  // figure out the archive_name (i.e. the zip folder all the other files are in)
  // all lookups to getRecord will be prefixed by this folder
  int n = mz_zip_reader_get_num_files(ar_.get());
  if (n == 0) {
    CAFFE_THROW("archive does not contain any files");
  }
  size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
  valid("getting filename");
  std::string buf(name_size, '\0');
  mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
  valid("getting filename");
  auto pos = buf.find_first_of('/');
  if (pos == std::string::npos) {
    CAFFE_THROW("file in archive is not in a subdirectory: ", buf);
  }
  archive_name_ = buf.substr(0, pos);
  archive_name_plus_slash_ = archive_name_ + "/";

  // version check
  at::DataPtr version_ptr;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  size_t version_size;
  if (hasRecord(".data/version")) {
    std::tie(version_ptr, version_size) = getRecord(".data/version");
  } else {
    TORCH_CHECK(hasRecord("version"));
    std::tie(version_ptr, version_size) = getRecord("version");
  }
  std::string version(static_cast<const char*>(version_ptr.get()), version_size);
  try {
    version_ = caffe2::stoull(version);
  } catch (const std::invalid_argument& e) {
    CAFFE_THROW("Couldn't parse the version ",
                version,
                " as Long Long.");
  }
  // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
  if (version_ < kMinSupportedFileFormatVersion) {
    CAFFE_THROW(
        "Attempted to read a PyTorch file with version ",
        c10::to_string(version_),
        ", but the minimum supported version for reading is ",
        c10::to_string(kMinSupportedFileFormatVersion),
        ". Your PyTorch script module file is too old. Please regenerate it",
        " with latest version of PyTorch to mitigate this issue.");
  }

  // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
  if (version_ > kMaxSupportedFileFormatVersion) {
    CAFFE_THROW(
        "Attempted to read a PyTorch file with version ",
        version_,
        ", but the maximum supported version for reading is ",
        kMaxSupportedFileFormatVersion,
        ". The version of your PyTorch installation may be too old, ",
        "please upgrade PyTorch to latest version to mitigate this issue.");
  }
}

void PyTorchStreamReader::valid(const char* what, const char* info) {
  const auto err = mz_zip_get_last_error(ar_.get());
  TORCH_CHECK(
      err == MZ_ZIP_NO_ERROR,
      "PytorchStreamReader failed ",
      what,
      info,
      ": ",
      mz_zip_get_error_string(err));
}

constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;

namespace detail {
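// Computes how many padding bytes are needed so that the record's payload
// starts at a kFieldAlignment-aligned offset, accounting for the zip local
// header, the filename, and any zip64 extra fields. The padding is emitted
// as an "FB" zip extra field; the return value is the total size of that
// extra field (padding plus the 4-byte key/length header).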
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf) {
  size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
      sizeof(mz_uint16) * 2;
  if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
    start += sizeof(mz_uint16) * 2;
    if (size >= MZ_UINT32_MAX) {
      start += 2 * sizeof(mz_uint64);
    }
    if (cursor >= MZ_UINT32_MAX) {
      start += sizeof(mz_uint64);
    }
  }
  size_t mod = start % kFieldAlignment;
  size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
  size_t padding_size = next_offset - start;
  size_t padding_size_plus_fbxx = padding_size + 4;
  if (padding_buf.size() < padding_size_plus_fbxx) {
    padding_buf.append(padding_size_plus_fbxx - padding_buf.size(), 'Z');
  }
  // zip extra encoding (key, size_of_extra_bytes)
  padding_buf[0] = 'F';
  padding_buf[1] = 'B';
  padding_buf[2] = (uint8_t)padding_size;
  padding_buf[3] = (uint8_t)(padding_size >> 8);
  return padding_size_plus_fbxx;
}
} // namespace detail

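// Returns true if the archive contains a record with the given name
// (looked up under the archive's top-level folder). Debug symbol records
// (*.debug_pkl) are reported as absent when load_debug_symbol_ is false.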
bool PyTorchStreamReader::hasRecord(const std::string& name) {
  std::lock_guard<std::mutex> guard(reader_lock_);

  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    return false;
  }
  std::string ss = archive_name_plus_slash_ + name;
  mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
  const mz_zip_error err = mz_zip_get_last_error(ar_.get());

  if (err == MZ_ZIP_NO_ERROR) {
    return true;
  } else if (err == MZ_ZIP_FILE_NOT_FOUND) {
    return false;
  } else {
    // A different error happened, raise it.
    valid("attempting to locate file ", name.c_str());
  }
  TORCH_INTERNAL_ASSERT(false, "should not reach here");
}

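// Lists every record name in the archive with the top-level folder prefix
// stripped. Debug symbol records (*.debug_pkl) are omitted unless
// load_debug_symbol_ is set.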
std::vector<std::string> PyTorchStreamReader::getAllRecords() {
  std::lock_guard<std::mutex> guard(reader_lock_);
  mz_uint num_files = mz_zip_reader_get_num_files(ar_.get());
  std::vector<std::string> out;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
  for (size_t i = 0; i < num_files; i++) {
    mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
    if (strncmp(
            buf,
            archive_name_plus_slash_.data(),
            archive_name_plus_slash_.size()) != 0) {
      CAFFE_THROW(
          "file in archive is not in a subdirectory ",
          archive_name_plus_slash_,
          ": ",
          buf);
    }
    if ((load_debug_symbol_) ||
        (!c10::string_view(buf + archive_name_plus_slash_.size()).ends_with(kDebugPklSuffix))) {
      // NOLINTNEXTLINE(modernize-use-emplace)
      out.push_back(buf + archive_name_plus_slash_.size());
    }
  }
  return out;
}

const std::unordered_set<std::string>&
PyTorchStreamWriter::getAllWrittenRecords() {
  return files_written_;
}

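// Maps a record name to its index in the zip central directory; throws via
// valid() if the record cannot be located.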
size_t PyTorchStreamReader::getRecordID(const std::string& name) {
  std::string ss = archive_name_plus_slash_ + name;
  size_t result = mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
  valid("locating file ", name.c_str());
  return result;
}

// Returns a (DataPtr, size) tuple holding the record's uncompressed contents.
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    at::DataPtr retval;
    return std::make_tuple(std::move(retval), 0);
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  valid("retrieving file meta-data for ", name.c_str());
  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
  mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
  valid("reading file ", name.c_str());

  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}

static int64_t read_le_16(uint8_t* buf) {
  return buf[0] + (buf[1] << 8);
}

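// Returns the absolute offset of the record's payload within the file by
// reading the local zip header and skipping past the filename and extra
// fields, so callers can read the payload directly from the underlying file.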
size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
  valid("retrieving file meta-data for ", name.c_str());
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
  in_->read(
      stat.m_local_header_ofs,
      local_header,
      MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
      "reading file header");
  size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
  size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
  return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
}

PyTorchStreamReader::~PyTorchStreamReader() {
  mz_zip_clear_last_error(ar_.get());
  mz_zip_reader_end(ar_.get());
  valid("closing reader for archive ", archive_name_.c_str());
}

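// miniz write callback: forwards writes to the owning PyTorchStreamWriter's
// writer_func_, verifying that writes arrive at the expected offset and
// recording any short write as an error.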
size_t ostream_write_func(
    void* pOpaque,
    mz_uint64 file_ofs,
    const void* pBuf,
    size_t n) {
  auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
  if (self->current_pos_ != file_ofs) {
    CAFFE_THROW("unexpected pos ", self->current_pos_, " vs ", file_ofs);
  }
  size_t ret = self->writer_func_(pBuf, n);
  if (n != ret) {
    self->err_seen_ = true;
  }
  self->current_pos_ += ret;
  return ret;
}

PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
    : archive_name_(basename(file_name)) {
  setup(file_name);
}

PyTorchStreamWriter::PyTorchStreamWriter(
    const std::function<size_t(const void*, size_t)> writer_func)
    : archive_name_("archive"),
      writer_func_(writer_func) {
  setup(archive_name_);
}

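// Shared setup for both constructors: validates the archive name, opens the
// output file stream when no custom writer_func_ was supplied (checking that
// the parent directory exists), and initializes the zip64 writer.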
void PyTorchStreamWriter::setup(const string& file_name) {
  ar_ = std::make_unique<mz_zip_archive>();
  memset(ar_.get(), 0, sizeof(mz_zip_archive));
  archive_name_plus_slash_ = archive_name_ + "/"; // for writeRecord().

  if (archive_name_.size() == 0) {
    CAFFE_THROW("invalid file name: ", file_name);
  }
  if (!writer_func_) {
    file_stream_.open(
        file_name,
        std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
    valid("opening archive ", file_name.c_str());

    const std::string dir_name = parentdir(file_name);
    if (!dir_name.empty()) {
      struct stat st;
      bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
      TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
    }
    TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
    writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
      file_stream_.write(static_cast<const char*>(buf), nbytes);
      return !file_stream_ ? 0 : nbytes;
    };
  }

  ar_->m_pIO_opaque = this;
  ar_->m_pWrite = ostream_write_func;

  mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
  valid("initializing archive ", file_name.c_str());
}

void PyTorchStreamWriter::setMinVersion(const uint64_t version) {
  version_ = std::max(version, version_);
}

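// Writes a record into the archive under the archive's top-level folder,
// padding the entry (via detail::getPadding) so the payload is aligned to
// kFieldAlignment. Each record name may be written only once.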
void PyTorchStreamWriter::writeRecord(
    const std::string& name,
    const void* data,
    size_t size,
    bool compress) {
  AT_ASSERT(!finalized_);
  AT_ASSERT(!archive_name_plus_slash_.empty());
  TORCH_INTERNAL_ASSERT(
      files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
  std::string full_name = archive_name_plus_slash_ + name;
  size_t padding_size =
      detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
  uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
  mz_zip_writer_add_mem_ex_v2(
      ar_.get(),
      full_name.c_str(),
      data,
      size,
      nullptr,
      0,
      flags,
      0,
      0,
      nullptr,
      padding_.c_str(),
      padding_size,
      nullptr,
      0);
  valid("writing file ", name.c_str());
  files_written_.insert(name);
}

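// Finalizes the archive: writes the version record if the caller did not,
// emits the zip central directory, and closes the output stream. After this
// call the writer is marked finalized and rejects further writes.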
void PyTorchStreamWriter::writeEndOfFile() {
  // Ensures that finalized_ is set to true even if an
  // exception is raised during the method call.
  // I.e. even a partial call to writeEndOfFile() should mark the
  // file as finalized; otherwise a double exception raised from the
  // destructor would result in `std::terminate()`.
  // See https://github.com/pytorch/pytorch/issues/87997/
  struct Finalizer {
    Finalizer(bool& var): var_(var) {}
    ~Finalizer() {
      var_ = true;
    }
   private:
    bool& var_;
  } f(finalized_);

  auto allRecords = getAllWrittenRecords();
  // If the output model has no ".data/version" or "version" record, write the version info.
  if (allRecords.find(".data/version") == allRecords.end() &&
      allRecords.find("version") == allRecords.end()) {
    std::string version = c10::to_string(version_);
    version.push_back('\n');
    if (version_ >= 0x6L) {
      writeRecord(".data/version", version.c_str(), version.size());
    } else {
      writeRecord("version", version.c_str(), version.size());
    }
  }

  AT_ASSERT(!finalized_);
  finalized_ = true;

  mz_zip_writer_finalize_archive(ar_.get());
  mz_zip_writer_end(ar_.get());
  valid("writing central directory for archive ", archive_name_.c_str());
  if (file_stream_.is_open()) {
    file_stream_.close();
  }
}

void PyTorchStreamWriter::valid(const char* what, const char* info) {
  auto err = mz_zip_get_last_error(ar_.get());
  if (err != MZ_ZIP_NO_ERROR) {
    CAFFE_THROW(
        "PytorchStreamWriter failed ",
        what,
        info,
        ": ",
        mz_zip_get_error_string(err));
  }
  if (err_seen_) {
    CAFFE_THROW("PytorchStreamWriter failed ", what, info, ".");
  }
}

// NOLINTNEXTLINE(bugprone-exception-escape)
PyTorchStreamWriter::~PyTorchStreamWriter() {
  if (!finalized_) {
    writeEndOfFile();
  }
}

} // namespace serialize
} // namespace caffe2