1 | #include <cstdio> |
2 | #include <cstring> |
3 | #include <cerrno> |
4 | #include <istream> |
5 | #include <ostream> |
6 | #include <fstream> |
7 | #include <algorithm> |
8 | #include <sys/stat.h> |
9 | #include <sys/types.h> |
10 | |
11 | |
12 | #include <c10/core/Allocator.h> |
13 | #include <c10/core/CPUAllocator.h> |
14 | #include <c10/core/Backend.h> |
15 | #include <c10/util/Exception.h> |
16 | |
17 | #include "caffe2/core/common.h" |
18 | #include "caffe2/core/logging.h" |
19 | #include "caffe2/serialize/file_adapter.h" |
20 | #include "caffe2/serialize/inline_container.h" |
21 | #include "caffe2/serialize/istream_adapter.h" |
22 | #include "caffe2/serialize/read_adapter_interface.h" |
23 | |
24 | #include "miniz.h" |
25 | |
26 | namespace caffe2 { |
27 | namespace serialize { |
// Suffix marking records that hold debug symbol info; such records are
// skipped by the reader unless load_debug_symbol_ is enabled.
constexpr c10::string_view kDebugPklSuffix(".debug_pkl" );
29 | |
30 | size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) { |
31 | auto self = static_cast<PyTorchStreamReader*>(pOpaque); |
32 | return self->read(file_ofs, static_cast<char*>(pBuf), n); |
33 | } |
34 | |
// Extracts the file name from a path, with directory components (split on
// either '/' or '\\') and the trailing extension (text after the last '.')
// removed. Returns "" when the path ends in a separator.
static std::string basename(const std::string& name) {
  const size_t sep = name.find_last_of("/\\");
  const size_t start = (sep == std::string::npos) ? 0 : sep + 1;

  if (start >= name.size()) {
    return "";
  }

  // Cut at the last '.' that falls inside the basename, if there is one.
  const size_t dot = name.find_last_of('.');
  const size_t end =
      (dot != std::string::npos && dot >= start) ? dot : name.size();
  return name.substr(start, end - start);
}
55 | |
// Returns the directory portion of a path (everything before the last
// separator), or "" when there is none. A forward slash is preferred; a
// backslash is only consulted when no '/' exists anywhere in the path.
static std::string parentdir(const std::string& name) {
  auto cut = name.find_last_of('/');
  if (cut == std::string::npos) {
    cut = name.find_last_of('\\');
  }
  return (cut == std::string::npos) ? std::string() : name.substr(0, cut);
}
66 | |
// Reads n bytes at absolute offset pos from the underlying adapter into buf.
// Returns the number of bytes actually read.
size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
  return in_->read(pos, buf, n, "reading file" );
}
70 | |
// Opens an archive backed by a file on disk.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
    : ar_(std::make_unique<mz_zip_archive>()),
      in_(std::make_unique<FileAdapter>(file_name)) {
  init();
}
77 | |
// Opens an archive backed by a caller-owned std::istream; the stream must
// outlive this reader.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
    : ar_(std::make_unique<mz_zip_archive>()),
      in_(std::make_unique<IStreamAdapter>(in)) {
  init();
}
84 | |
// Opens an archive backed by an arbitrary read adapter (shared ownership).
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PyTorchStreamReader::PyTorchStreamReader(
    std::shared_ptr<ReadAdapterInterface> in)
    : ar_(std::make_unique<mz_zip_archive>()), in_(std::move(in)) {
  init();
}
91 | |
// Validates the archive, determines the top-level folder name that prefixes
// every record, and checks the stored file-format version against the
// supported [kMin, kMax] range. Throws on any validation failure.
void PyTorchStreamReader::init() {
  AT_ASSERT(in_ != nullptr);
  AT_ASSERT(ar_ != nullptr);
  memset(ar_.get(), 0, sizeof(mz_zip_archive));

  size_t size = in_->size();

  // check for the old magic number,
  constexpr size_t kMagicValueLength = 8;
  if (size > kMagicValueLength) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    char buf[kMagicValueLength];
    read(0, buf, kMagicValueLength);
    valid("checking magic number" );
    // Reject pre-release "PYTORCH1" archives, which predate the zip format.
    AT_ASSERTM(
        memcmp("PYTORCH1" , buf, kMagicValueLength) != 0,
        "File is an unsupported archive format from the preview release." );
  }

  // Hook miniz up to our adapter via the read callback.
  ar_->m_pIO_opaque = this;
  ar_->m_pRead = istream_read_func;

  mz_zip_reader_init(ar_.get(), size, 0);
  valid("reading zip archive" );

  // figure out the archive_name (i.e. the zip folder all the other files are in)
  // all lookups to getRecord will be prefixed by this folder
  int n = mz_zip_reader_get_num_files(ar_.get());
  if (n == 0) {
    CAFFE_THROW("archive does not contain any files" );
  }
  // First call with a null buffer returns the required filename length.
  size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
  valid("getting filename" );
  std::string buf(name_size, '\0');
  mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
  valid("getting filename" );
  auto pos = buf.find_first_of('/');
  if (pos == std::string::npos) {
    CAFFE_THROW("file in archive is not in a subdirectory: " , buf);
  }
  archive_name_ = buf.substr(0, pos);
  archive_name_plus_slash_ = archive_name_ + "/" ;

  // version check
  // Newer archives store the version under ".data/version"; older ones use a
  // top-level "version" record.
  at::DataPtr version_ptr;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  size_t version_size;
  if (hasRecord(".data/version" )) {
    std::tie(version_ptr, version_size) = getRecord(".data/version" );
  } else {
    TORCH_CHECK(hasRecord("version" ))
    std::tie(version_ptr, version_size) = getRecord("version" );
  }
  std::string version(static_cast<const char*>(version_ptr.get()), version_size);
  try {
    version_ = caffe2::stoull(version);
  } catch (const std::invalid_argument &e) {
    CAFFE_THROW("Couldn't parse the version " ,
      version,
      " as Long Long." );
  }
  // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
  if (version_ < kMinSupportedFileFormatVersion) {
    CAFFE_THROW(
        "Attempted to read a PyTorch file with version " ,
        c10::to_string(version_),
        ", but the minimum supported version for reading is " ,
        c10::to_string(kMinSupportedFileFormatVersion),
        ". Your PyTorch script module file is too old. Please regenerate it" ,
        " with latest version of PyTorch to mitigate this issue." );
  }

  // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
  if (version_ > kMaxSupportedFileFormatVersion) {
    CAFFE_THROW(
        "Attempted to read a PyTorch file with version " ,
        version_,
        ", but the maximum supported version for reading is " ,
        kMaxSupportedFileFormatVersion,
        ". The version of your PyTorch installation may be too old, " ,
        "please upgrade PyTorch to latest version to mitigate this issue." );
  }
}
175 | |
// Throws if miniz has recorded an error on the archive; `what`/`info` are
// interpolated into the exception message to identify the failing operation.
void PyTorchStreamReader::valid(const char* what, const char* info) {
  const auto err = mz_zip_get_last_error(ar_.get());
  TORCH_CHECK(
      err == MZ_ZIP_NO_ERROR,
      "PytorchStreamReader failed " ,
      what,
      info,
      ": " ,
      mz_zip_get_error_string(err));
}
186 | |
// Zip local-file-header layout constants (see the ZIP APPNOTE spec):
// total fixed header size, and the byte offsets of the 16-bit filename-length
// and extra-field-length fields within that header.
// NOTE: the original lines lost their identifiers; the names are recovered
// from their uses in getPadding()/getRecordOffset() below.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
190 | |
namespace detail {
// Computes the zip "extra" field needed so that the record payload following
// the local file header starts at an offset aligned to kFieldAlignment.
// `cursor` is the archive offset where the header will be written, `size` the
// uncompressed payload size. Writes the encoded field into padding_buf and
// returns its total length (alignment padding + 4-byte FB header).
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf) {
  // Payload start = header + filename + the 2x2-byte ("FB", length) extra
  // field header that is always appended.
  size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
      sizeof(mz_uint16) * 2;
  if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
    // zip64: miniz emits an additional extended-information extra field with
    // 64-bit copies of the values that overflow 32 bits.
    start += sizeof(mz_uint16) * 2;
    if (size >= MZ_UINT32_MAX) {
      start += 2 * sizeof(mz_uint64);
    }
    if (cursor >= MZ_UINT32_MAX) {
      start += sizeof(mz_uint64);
    }
  }
  // Round `start` up to the next multiple of kFieldAlignment.
  size_t mod = start % kFieldAlignment;
  size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
  size_t padding_size = next_offset - start;
  size_t padding_size_plus_fbxx = padding_size + 4;
  // Grow the reusable buffer if needed; filler bytes are 'Z'.
  if (padding_buf.size() < padding_size_plus_fbxx) {
    padding_buf.append(padding_size_plus_fbxx - padding_buf.size(), 'Z');
  }
  // zip extra encoding (key, size_of_extra_bytes)
  padding_buf[0] = 'F';
  padding_buf[1] = 'B';
  padding_buf[2] = (uint8_t)padding_size;
  padding_buf[3] = (uint8_t)(padding_size >> 8);
  return padding_size_plus_fbxx;
}
} // namespace detail
223 | |
224 | bool PyTorchStreamReader::hasRecord(const std::string& name) { |
225 | std::lock_guard<std::mutex> guard(reader_lock_); |
226 | |
227 | if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) { |
228 | return false; |
229 | } |
230 | std::string ss = archive_name_plus_slash_ + name; |
231 | mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0); |
232 | const mz_zip_error err = mz_zip_get_last_error(ar_.get()); |
233 | |
234 | if (err == MZ_ZIP_NO_ERROR) { |
235 | return true; |
236 | } else if (err == MZ_ZIP_FILE_NOT_FOUND) { |
237 | return false; |
238 | } else { |
239 | // A different error happened, raise it. |
240 | valid("attempting to locate file " , name.c_str()); |
241 | } |
242 | TORCH_INTERNAL_ASSERT(false, "should not reach here" ); |
243 | } |
244 | |
// Lists every record name in the archive with the "<archive_name>/" prefix
// stripped. Throws if any entry is not under that prefix. Debug-symbol
// records are omitted unless load_debug_symbol_ is set.
std::vector<std::string> PyTorchStreamReader::getAllRecords() {
  std::lock_guard<std::mutex> guard(reader_lock_);
  mz_uint num_files = mz_zip_reader_get_num_files(ar_.get());
  std::vector<std::string> out;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
  for (size_t i = 0; i < num_files; i++) {
    mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
    // Every record must live inside the archive's top-level folder.
    if (strncmp(
            buf,
            archive_name_plus_slash_.data(),
            archive_name_plus_slash_.size()) != 0) {
      CAFFE_THROW(
          "file in archive is not in a subdirectory " ,
          archive_name_plus_slash_,
          ": " ,
          buf);
    }
    if ((load_debug_symbol_) ||
        (!c10::string_view(buf + archive_name_plus_slash_.size()).ends_with(kDebugPklSuffix))) {
      // NOLINTNEXTLINE(modernize-use-emplace)
      out.push_back(buf + archive_name_plus_slash_.size());
    }
  }
  return out;
}
271 | |
// Returns the set of record names written so far (without the archive-name
// prefix). The reference stays valid for the lifetime of the writer.
const std::unordered_set<std::string>&
PyTorchStreamWriter::getAllWrittenRecords() {
  return files_written_;
}
276 | |
// Resolves `name` (without the archive prefix) to its miniz file index,
// throwing if the record cannot be located.
// NOTE(review): takes no lock itself — presumably callers hold reader_lock_;
// verify before calling from new code paths.
size_t PyTorchStreamReader::getRecordID(const std::string& name) {
  std::string ss = archive_name_plus_slash_ + name;
  size_t result = mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
  valid("locating file " , name.c_str());
  return result;
}
283 | |
// Reads the full contents of record `name` into freshly allocated CPU memory.
// return dataptr, size
// Debug-symbol records read back as (empty DataPtr, 0) unless
// load_debug_symbol_ is set.
std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
  std::lock_guard<std::mutex> guard(reader_lock_);
  if ((!load_debug_symbol_) && c10::string_view(name).ends_with(kDebugPklSuffix)) {
    at::DataPtr retval;
    return std::make_tuple(std::move(retval), 0);
  }
  size_t key = getRecordID(name);
  mz_zip_archive_file_stat stat;
  mz_zip_reader_file_stat(ar_.get(), key, &stat);
  valid("retrieving file meta-data for " , name.c_str());
  // Decompress the whole record into a buffer sized from the zip metadata.
  at::DataPtr retval = c10::GetCPUAllocator()->allocate(stat.m_uncomp_size);
  mz_zip_reader_extract_to_mem(ar_.get(), key, retval.get(), stat.m_uncomp_size, 0);
  valid("reading file " , name.c_str());

  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
}
301 | |
// Decodes a 16-bit little-endian value from two consecutive bytes.
static int64_t read_le_16(uint8_t* buf) {
  const int64_t lo = buf[0];
  const int64_t hi = buf[1];
  return lo | (hi << 8);
}
305 | |
306 | size_t PyTorchStreamReader::getRecordOffset(const std::string& name) { |
307 | std::lock_guard<std::mutex> guard(reader_lock_); |
308 | mz_zip_archive_file_stat stat; |
309 | mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat); |
310 | valid("retrieving file meta-data for " , name.c_str()); |
311 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
312 | uint8_t [MZ_ZIP_LOCAL_DIR_HEADER_SIZE]; |
313 | in_->read( |
314 | stat.m_local_header_ofs, |
315 | local_header, |
316 | MZ_ZIP_LOCAL_DIR_HEADER_SIZE, |
317 | "reading file header" ); |
318 | size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS); |
319 | size_t = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS); |
320 | return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len; |
321 | } |
322 | |
323 | |
PyTorchStreamReader::~PyTorchStreamReader() {
  // Drop any stale miniz error first so valid() below only reports failures
  // from tearing down the reader itself.
  mz_zip_clear_last_error(ar_.get());
  mz_zip_reader_end(ar_.get());
  valid("closing reader for archive " , archive_name_.c_str());
}
329 | |
330 | size_t ostream_write_func( |
331 | void* pOpaque, |
332 | mz_uint64 file_ofs, |
333 | const void* pBuf, |
334 | size_t n) { |
335 | auto self = static_cast<PyTorchStreamWriter*>(pOpaque); |
336 | if (self->current_pos_ != file_ofs) { |
337 | CAFFE_THROW("unexpected pos " , self->current_pos_, " vs " , file_ofs); |
338 | } |
339 | size_t ret = self->writer_func_(pBuf, n); |
340 | if (n != ret) { |
341 | self->err_seen_ = true; |
342 | } |
343 | self->current_pos_ += ret; |
344 | return ret; |
345 | } |
346 | |
// Creates a writer that streams the archive to `file_name`; the archive's
// internal folder name is the file's basename (directories/extension removed).
PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name)
    : archive_name_(basename(file_name)) {
  setup(file_name);
}
351 | |
// Creates a writer that emits archive bytes through a caller-supplied
// function; the internal folder name defaults to "archive".
PyTorchStreamWriter::PyTorchStreamWriter(
    const std::function<size_t(const void*, size_t)> writer_func)
    : archive_name_("archive" ),
      writer_func_(writer_func) {
  setup(archive_name_);
}
358 | |
// Shared constructor body: initializes the miniz writer, and when no custom
// writer_func_ was supplied, opens `file_name` for binary output and installs
// a default writer that appends to that stream.
void PyTorchStreamWriter::setup(const string& file_name) {
  ar_ = std::make_unique<mz_zip_archive>();
  memset(ar_.get(), 0, sizeof(mz_zip_archive));
  archive_name_plus_slash_ = archive_name_ + "/" ; // for writeRecord().

  if (archive_name_.size() == 0) {
    CAFFE_THROW("invalid file name: " , file_name);
  }
  if (!writer_func_) {
    file_stream_.open(
        file_name,
        std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
    valid("opening archive " , file_name.c_str());

    // Produce a clearer error than the generic stream failure when the
    // destination directory is missing.
    const std::string dir_name = parentdir(file_name);
    if(!dir_name.empty()) {
      struct stat st;
      bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
      TORCH_CHECK(dir_exists, "Parent directory " , dir_name, " does not exist." );
    }
    TORCH_CHECK(file_stream_, "File " , file_name, " cannot be opened." );
    // Default writer: append to the owned file stream; report 0 on failure so
    // ostream_write_func flags the error.
    writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
      file_stream_.write(static_cast<const char*>(buf), nbytes);
      return !file_stream_ ? 0 : nbytes;
    };
  }

  // Hook miniz up to our writer via the write callback.
  ar_->m_pIO_opaque = this;
  ar_->m_pWrite = ostream_write_func;

  mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
  valid("initializing archive " , file_name.c_str());
}
392 | |
393 | void PyTorchStreamWriter::setMinVersion(const uint64_t version) { |
394 | version_ = std::max(version, version_); |
395 | } |
396 | |
// Appends one record under "<archive_name>/<name>". A padding extra field is
// inserted so the payload starts at an aligned archive offset (see
// detail::getPadding). Each record name may be written only once, and nothing
// may be written after writeEndOfFile().
void PyTorchStreamWriter::writeRecord(
    const std::string& name,
    const void* data,
    size_t size,
    bool compress) {
  AT_ASSERT(!finalized_);
  AT_ASSERT(!archive_name_plus_slash_.empty());
  TORCH_INTERNAL_ASSERT(
      files_written_.count(name) == 0, "Tried to serialize file twice: " , name);
  std::string full_name = archive_name_plus_slash_ + name;
  size_t padding_size =
      detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
  uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
  mz_zip_writer_add_mem_ex_v2(
      ar_.get(),
      full_name.c_str(),
      data,
      size,
      nullptr,
      0,
      flags,
      0,
      0,
      nullptr,
      padding_.c_str(),   // user_extra_data: alignment padding field
      padding_size,
      nullptr,
      0);
  valid("writing file " , name.c_str());
  files_written_.insert(name);
}
428 | |
// Writes the version record (if the caller didn't), finalizes the zip central
// directory, and closes the owned file stream if any. After this returns (or
// throws), the writer is finalized and no further records may be written.
void PyTorchStreamWriter::writeEndOfFile() {
  // Ensures that finalized_ is set to true even if an
  // exception is raised during the method call.
  // I.e. even a partial call to writeEndOfFile() should mark the
  // file as finalized, otherwise a double exception raised from the
  // destructor would result in `std::terminate()`.
  // See https://github.com/pytorch/pytorch/issues/87997/
  struct Finalizer {
    Finalizer(bool& var): var_(var) {}
    ~Finalizer() {
      var_ = true;
    }
   private:
    bool& var_;
  } f(finalized_);

  auto allRecords = getAllWrittenRecords();
  // If no ".data/version" or "version" record in the output model, rewrites version info
  if(allRecords.find(".data/version" ) == allRecords.end() && allRecords.find("version" ) == allRecords.end()) {
    std::string version = c10::to_string(version_);
    version.push_back('\n');
    // Format version 6+ stores the record under ".data/"; older formats use a
    // top-level "version" record.
    if (version_ >= 0x6L) {
      writeRecord(".data/version" , version.c_str(), version.size());
    } else {
      writeRecord("version" , version.c_str(), version.size());
    }
  }

  AT_ASSERT(!finalized_);
  finalized_ = true;

  mz_zip_writer_finalize_archive(ar_.get());
  mz_zip_writer_end(ar_.get());
  valid("writing central directory for archive " , archive_name_.c_str());
  if (file_stream_.is_open()) {
    file_stream_.close();
  }
}
467 | |
// Throws if miniz has recorded an error on the archive, or if a previous
// short write was flagged by ostream_write_func (err_seen_).
void PyTorchStreamWriter::valid(const char* what, const char* info) {
  auto err = mz_zip_get_last_error(ar_.get());
  if (err != MZ_ZIP_NO_ERROR) {
    CAFFE_THROW(
        "PytorchStreamWriter failed " ,
        what,
        info,
        ": " ,
        mz_zip_get_error_string(err));
  }
  if (err_seen_) {
    CAFFE_THROW("PytorchStreamWriter failed " , what, info, "." );
  }
}
482 | |
// Finalizes the archive on destruction if the caller never called
// writeEndOfFile() explicitly. May throw from writeEndOfFile(), hence:
// NOLINTNEXTLINE(bugprone-exception-escape)
PyTorchStreamWriter::~PyTorchStreamWriter() {
  if (!finalized_) {
    writeEndOfFile();
  }
}
489 | |
490 | } // namespace serialize |
491 | } // namespace caffe2 |
492 | |