1 | #pragma once |
2 | |
3 | #include <cerrno> |
4 | #include <cstdio> |
5 | #include <cstring> |
6 | #include <fstream> |
7 | #include <istream> |
8 | #include <mutex> |
9 | #include <ostream> |
10 | #include <unordered_set> |
11 | |
12 | #include <c10/core/Allocator.h> |
13 | #include <c10/core/Backend.h> |
14 | |
15 | #include "caffe2/serialize/istream_adapter.h" |
16 | #include "caffe2/serialize/read_adapter_interface.h" |
17 | #include "caffe2/serialize/versions.h" |
18 | |
19 | extern "C" { |
20 | typedef struct mz_zip_archive mz_zip_archive; |
21 | } |
22 | |
23 | // PyTorch containers are a special zip archive with the following layout |
24 | // archive_name.zip contains: |
25 | // archive_name/ |
26 | // version # a file with a single decimal number written in ascii, |
27 | // # used to establish the version of the archive format |
28 | // model.json # overall model description, this is a json output of |
29 | // # ModelDef from torch.proto |
30 | // # the following names are by convention only, model.json will |
31 | // # refer to these files by full names |
32 | // tensors/ |
33 | // 0 # flat storage for tensor data, meta-data about shapes, etc. is |
34 | // # in model.json |
35 | // 1 |
36 | // ... |
37 | // # code entries will only exist for modules that have methods attached |
38 | // code/ |
39 | // archive_name.py # serialized torch script code (python syntax, using |
40 | // PythonPrint) archive_name_my_submodule.py # submodules have separate |
41 | // files |
42 | // |
43 | // The PyTorchStreamWriter also ensures additional useful properties for these |
44 | // files |
45 | // 1. All files are stored uncompressed. |
46 | // 2. All files in the archive are aligned to 64 byte boundaries such that |
47 | // it is possible to mmap the entire file and get an aligned pointer to |
48 | // tensor data. |
49 | // 3. We universally write in ZIP64 format for consistency. |
50 | |
51 | // The PyTorchStreamReader also provides additional properties: |
52 | // 1. It can read zip files that are created with common |
53 | // zip tools. This means that even though our writer doesn't compress files, |
54 | // the reader can still read files that were compressed. |
55 | // 2. It provides a getRecordOffset function which returns the offset into the |
56 | // raw file where file data lives. If the file was written with |
57 | // PyTorchStreamWriter it is guaranteed to be 64 byte aligned. |
58 | |
59 | // PyTorchReader/Writer handle checking the version number on the archive format |
// and ensure that all files are written to an archive_name directory so they
61 | // unzip cleanly. |
62 | |
63 | // When developing this format we want to pay particular attention to the |
64 | // following use cases: |
65 | // |
66 | // -- Reading -- |
67 | // 1) Reading with full random access |
68 | // a) Reading with file api's such as fread() |
69 | // b) mmaping the file and jumping around the mapped region |
70 | // 2) Reading with 1-pass sequential access |
71 | // -> A reader will need to build up a data structure of parsed structures |
72 | // as it reads |
73 | // |
74 | // -- Writing -- |
75 | // 1) Writing with full random access |
76 | // 2) Writing with 1-pass sequential access |
77 | // -> We must take care not to require updating values that have already |
78 | // been written. We place the variable-length index at the end and do |
// not put any indices into the header to fulfill this constraint.
80 | |
81 | // The model.json, which contains all the metadata information, |
82 | // should be written as the last file. One reason is that the size of tensor |
83 | // data is usually stable. As long as the shape and type of the tensor do not |
// change, the size of the data won't change. On the other hand, the size of the
85 | // serialized model is likely to change, so we store it as the last record, and |
86 | // we don't need to move previous records when updating the model data. |
87 | |
88 | // The zip format is sufficiently flexible to handle the above use-case. |
89 | // it puts its central directory at the end of the archive and we write |
90 | // model.json as the last file when writing after we have accumulated all |
91 | // other information. |
92 | |
93 | namespace caffe2 { |
94 | namespace serialize { |
95 | |
96 | class TORCH_API PyTorchStreamReader final { |
97 | public: |
98 | explicit PyTorchStreamReader(const std::string& file_name); |
99 | explicit PyTorchStreamReader(std::istream* in); |
100 | explicit PyTorchStreamReader(std::shared_ptr<ReadAdapterInterface> in); |
101 | |
102 | // return dataptr, size |
103 | std::tuple<at::DataPtr, size_t> getRecord(const std::string& name); |
104 | size_t getRecordOffset(const std::string& name); |
105 | bool hasRecord(const std::string& name); |
106 | std::vector<std::string> getAllRecords(); |
107 | |
108 | ~PyTorchStreamReader(); |
109 | uint64_t version() const { |
110 | return version_; |
111 | } |
112 | |
113 | void setShouldLoadDebugSymbol(bool should_load_debug_symbol) { |
114 | load_debug_symbol_ = should_load_debug_symbol; |
115 | } |
116 | |
117 | private: |
118 | void init(); |
119 | size_t read(uint64_t pos, char* buf, size_t n); |
120 | void valid(const char* what, const char* info = "" ); |
121 | size_t getRecordID(const std::string& name); |
122 | |
123 | friend size_t |
124 | istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n); |
125 | std::unique_ptr<mz_zip_archive> ar_; |
126 | std::string archive_name_; |
127 | std::string archive_name_plus_slash_; |
128 | std::shared_ptr<ReadAdapterInterface> in_; |
129 | int64_t version_; |
130 | std::mutex reader_lock_; |
131 | bool load_debug_symbol_ = true; |
132 | }; |
133 | |
// Writes a PyTorch container (see the file-level comment for the layout and
// the alignment/ZIP64 guarantees the writer provides).
class TORCH_API PyTorchStreamWriter final {
 public:
  // Write the archive to the file at `archive_name`.
  explicit PyTorchStreamWriter(std::string archive_name);
  // Stream the archive through `writer_func`, called with (data, size)
  // chunks; it returns the number of bytes consumed.
  explicit PyTorchStreamWriter(
      const std::function<size_t(const void*, size_t)> writer_func);

  // Presumably raises the produced format version to at least `version`
  // (used when serialized operators have valid upgraders) — see the
  // definition in the .cpp to confirm.
  void setMinVersion(const uint64_t version);

  // Add a record named `name` containing `size` bytes from `data`.
  // `compress` selects zip compression for this record (default: stored
  // uncompressed, per the container's alignment guarantees).
  void writeRecord(
      const std::string& name,
      const void* data,
      size_t size,
      bool compress = false);
  // Finalize the archive (central directory etc.); sets finalized_.
  void writeEndOfFile();

  // Names of all records written so far.
  const std::unordered_set<std::string>& getAllWrittenRecords();

  // True once writeEndOfFile() has completed.
  bool finalized() const {
    return finalized_;
  }

  const std::string& archiveName() {
    return archive_name_;
  }

  ~PyTorchStreamWriter();

 private:
  // Shared constructor logic; opens/prepares the output — see .cpp.
  void setup(const std::string& file_name);
  // Error check helper; `what`/`info` describe the failing operation.
  void valid(const char* what, const char* info = "" );
  size_t current_pos_ = 0;  // byte offset of the next write into the output
  std::unordered_set<std::string> files_written_;
  std::unique_ptr<mz_zip_archive> ar_;
  std::string archive_name_;
  std::string archive_name_plus_slash_;  // cached "archive_name/" prefix
  std::string padding_;  // scratch buffer for alignment padding records
  std::ofstream file_stream_;  // used by the file-name constructor
  std::function<size_t(const void*, size_t)> writer_func_;
  // This number will be updated when the model has operators
  // that have valid upgraders.
  uint64_t version_ = kMinProducedFileFormatVersion;
  bool finalized_ = false;
  bool err_seen_ = false;
  // miniz write callback; friend so it can reach writer internals.
  friend size_t ostream_write_func(
      void* pOpaque,
      uint64_t file_ofs,
      const void* pBuf,
      size_t n);
};
183 | |
namespace detail {
// Writer-specific constants.
// Every record's data is aligned to this boundary so that mmapping the
// archive yields aligned tensor pointers. `inline` (C++17) gives the
// constant a single definition across all translation units that include
// this header.
inline constexpr uint64_t kFieldAlignment = 64;

// Computes the padding needed so that record data starting after a local
// header written at `cursor` (with a name of `filename_size` bytes and a
// payload of `size` bytes) begins on a kFieldAlignment boundary. The padding
// record to append to the local extra-data field is written into
// `padding_buf`; returns the number of padding bytes.
size_t getPadding(
    size_t cursor,
    size_t filename_size,
    size_t size,
    std::string& padding_buf);
} // namespace detail
196 | |
197 | } // namespace serialize |
198 | } // namespace caffe2 |
199 | |