1#pragma once
2
3#include <cerrno>
4#include <cstdio>
5#include <cstring>
6#include <fstream>
7#include <istream>
8#include <mutex>
9#include <ostream>
10#include <unordered_set>
11
12#include <c10/core/Allocator.h>
13#include <c10/core/Backend.h>
14
15#include "caffe2/serialize/istream_adapter.h"
16#include "caffe2/serialize/read_adapter_interface.h"
17#include "caffe2/serialize/versions.h"
18
19extern "C" {
20typedef struct mz_zip_archive mz_zip_archive;
21}
22
23// PyTorch containers are a special zip archive with the following layout
24// archive_name.zip contains:
25// archive_name/
26// version # a file with a single decimal number written in ascii,
27// # used to establish the version of the archive format
28// model.json # overall model description, this is a json output of
29// # ModelDef from torch.proto
30// # the following names are by convention only, model.json will
31// # refer to these files by full names
32// tensors/
33// 0 # flat storage for tensor data, meta-data about shapes, etc. is
34// # in model.json
35// 1
36// ...
37// # code entries will only exist for modules that have methods attached
38// code/
39// archive_name.py # serialized torch script code (python syntax, using
40// PythonPrint) archive_name_my_submodule.py # submodules have separate
41// files
42//
43// The PyTorchStreamWriter also ensures additional useful properties for these
44// files
45// 1. All files are stored uncompressed.
46// 2. All files in the archive are aligned to 64 byte boundaries such that
47// it is possible to mmap the entire file and get an aligned pointer to
48// tensor data.
49// 3. We universally write in ZIP64 format for consistency.
50
51// The PyTorchStreamReader also provides additional properties:
52// 1. It can read zip files that are created with common
53// zip tools. This means that even though our writer doesn't compress files,
54// the reader can still read files that were compressed.
55// 2. It provides a getRecordOffset function which returns the offset into the
56// raw file where file data lives. If the file was written with
57// PyTorchStreamWriter it is guaranteed to be 64 byte aligned.
58
59// PyTorchReader/Writer handle checking the version number on the archive format
// and ensure that all files are written to an archive_name directory so they
61// unzip cleanly.
62
63// When developing this format we want to pay particular attention to the
64// following use cases:
65//
66// -- Reading --
67// 1) Reading with full random access
68// a) Reading with file api's such as fread()
69// b) mmaping the file and jumping around the mapped region
70// 2) Reading with 1-pass sequential access
71// -> A reader will need to build up a data structure of parsed structures
72// as it reads
73//
74// -- Writing --
75// 1) Writing with full random access
76// 2) Writing with 1-pass sequential access
77// -> We must take care not to require updating values that have already
78// been written. We place the variable-length index at the end and do
//      not put any indices into the header to fulfill this constraint.
80
81// The model.json, which contains all the metadata information,
82// should be written as the last file. One reason is that the size of tensor
83// data is usually stable. As long as the shape and type of the tensor do not
// change, the size of the data won't change. On the other side, the size of the
85// serialized model is likely to change, so we store it as the last record, and
86// we don't need to move previous records when updating the model data.
87
// The zip format is sufficiently flexible to handle the above use cases.
// It puts its central directory at the end of the archive, and we write
90// model.json as the last file when writing after we have accumulated all
91// other information.
92
93namespace caffe2 {
94namespace serialize {
95
96class TORCH_API PyTorchStreamReader final {
97 public:
98 explicit PyTorchStreamReader(const std::string& file_name);
99 explicit PyTorchStreamReader(std::istream* in);
100 explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in);
101
102 // return dataptr, size
103 std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
104 size_t getRecordOffset(const std::string& name);
105 bool hasRecord(const std::string& name);
106 std::vector<std::string> getAllRecords();
107
108 ~PyTorchStreamReader();
109 uint64_t version() const {
110 return version_;
111 }
112
113 void setShouldLoadDebugSymbol(bool should_load_debug_symbol) {
114 load_debug_symbol_ = should_load_debug_symbol;
115 }
116
117 private:
118 void init();
119 size_t read(uint64_t pos, char* buf, size_t n);
120 void valid(const char* what, const char* info = "");
121 size_t getRecordID(const std::string& name);
122
123 friend size_t
124 istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
125 std::unique_ptr<mz_zip_archive> ar_;
126 std::string archive_name_;
127 std::string archive_name_plus_slash_;
128 std::shared_ptr<ReadAdapterInterface> in_;
129 int64_t version_;
130 std::mutex reader_lock_;
131 bool load_debug_symbol_ = true;
132};
133
134class TORCH_API PyTorchStreamWriter final {
135 public:
136 explicit PyTorchStreamWriter(std::string archive_name);
137 explicit PyTorchStreamWriter(
138 const std::function<size_t(const void*, size_t)> writer_func);
139
140 void setMinVersion(const uint64_t version);
141
142 void writeRecord(
143 const std::string& name,
144 const void* data,
145 size_t size,
146 bool compress = false);
147 void writeEndOfFile();
148
149 const std::unordered_set<std::string>& getAllWrittenRecords();
150
151 bool finalized() const {
152 return finalized_;
153 }
154
155 const std::string& archiveName() {
156 return archive_name_;
157 }
158
159 ~PyTorchStreamWriter();
160
161 private:
162 void setup(const std::string& file_name);
163 void valid(const char* what, const char* info = "");
164 size_t current_pos_ = 0;
165 std::unordered_set<std::string> files_written_;
166 std::unique_ptr<mz_zip_archive> ar_;
167 std::string archive_name_;
168 std::string archive_name_plus_slash_;
169 std::string padding_;
170 std::ofstream file_stream_;
171 std::function<size_t(const void*, size_t)> writer_func_;
172 // This number will be updated when the model has operators
173 // that have valid upgraders.
174 uint64_t version_ = kMinProducedFileFormatVersion;
175 bool finalized_ = false;
176 bool err_seen_ = false;
177 friend size_t ostream_write_func(
178 void* pOpaque,
179 uint64_t file_ofs,
180 const void* pBuf,
181 size_t n);
182};
183
namespace detail {
// Writer-specific constants
// Every record payload is aligned to this boundary so the archive can be
// mmap'd and tensor data used in place (see the format notes at the top of
// this file).
constexpr uint64_t kFieldAlignment = 64;

// Returns a record to be appended to the local user extra data entry in order
// to make data beginning aligned at kFieldAlignment bytes boundary.
//
// cursor: current byte offset in the archive where the local header starts.
// filename_size: length of the record's name (contributes to header size).
// size: record payload size -- NOTE(review): presumably feeds ZIP64
//       bookkeeping in the padding record; confirm against the definition.
// padding_buf: receives the padding bytes; the return value is the number of
//       padding bytes produced.
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf);
} // namespace detail
196
197} // namespace serialize
198} // namespace caffe2
199