1 | #pragma once |
2 | |
3 | #include <array> |
4 | #include <cerrno> |
5 | #include <cstddef> |
6 | #include <cstring> |
7 | #include <fstream> |
8 | #include <istream> |
9 | #include <memory> |
10 | |
11 | #include <c10/core/CPUAllocator.h> |
12 | #include <c10/core/impl/alloc_cpu.h> |
13 | #include <caffe2/serialize/read_adapter_interface.h> |
14 | |
15 | #if defined(HAVE_MMAP) |
16 | #include <fcntl.h> |
17 | #include <sys/mman.h> |
18 | #include <sys/stat.h> |
19 | #include <sys/types.h> |
20 | #include <unistd.h> |
21 | #endif |
22 | |
23 | /** |
24 | * @file |
25 | * |
26 | * Helpers for identifying file formats when reading serialized data. |
27 | * |
28 | * Note that these functions are declared inline because they will typically |
29 | * only be called from one or two locations per binary. |
30 | */ |
31 | |
32 | namespace torch { |
33 | namespace jit { |
34 | |
/**
 * Identifies the container format of a file or data stream, as detected from
 * its leading magic bytes.
 */
enum class FileFormat {
  /// The header matched none of the known magic strings.
  UnknownFileFormat = 0,
  /// The header carries the Flatbuffer magic string.
  FlatbufferFileFormat = 1,
  /// The header carries the ZIP magic string.
  ZipFileFormat = 2,
};
43 | |
/// The size of the buffer to pass to #getFileFormat(), in bytes.
constexpr size_t kFileFormatHeaderSize = 8;
/// Granularity, in bytes, that the content-loading helpers in this file round
/// their buffer capacities up to.
constexpr size_t kMaxAlignment = 16;
47 | |
48 | /** |
49 | * Returns the likely file format based on the magic header bytes in @p header, |
50 | * which should contain the first bytes of a file or data stream. |
51 | */ |
52 | // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) |
53 | static inline FileFormat getFileFormat(const char* data) { |
54 | // The size of magic strings to look for in the buffer. |
55 | static constexpr size_t kMagicSize = 4; |
56 | |
57 | // Bytes 4..7 of a Flatbuffer-encoded file produced by |
58 | // `flatbuffer_serializer.h`. (The first four bytes contain an offset to the |
59 | // actual Flatbuffer data.) |
60 | static constexpr std::array<char, kMagicSize> kFlatbufferMagicString = { |
61 | 'P', 'T', 'M', 'F'}; |
62 | static constexpr size_t kFlatbufferMagicOffset = 4; |
63 | |
64 | // The first four bytes of a ZIP file. |
65 | static constexpr std::array<char, kMagicSize> kZipMagicString = { |
66 | 'P', 'K', '\x03', '\x04'}; |
67 | |
68 | // Note that we check for Flatbuffer magic first. Since the first four bytes |
69 | // of flatbuffer data contain an offset to the root struct, it's theoretically |
70 | // possible to construct a file whose offset looks like the ZIP magic. On the |
71 | // other hand, bytes 4-7 of ZIP files are constrained to a small set of values |
72 | // that do not typically cross into the printable ASCII range, so a ZIP file |
73 | // should never have a header that looks like a Flatbuffer file. |
74 | if (std::memcmp( |
75 | data + kFlatbufferMagicOffset, |
76 | kFlatbufferMagicString.data(), |
77 | kMagicSize) == 0) { |
78 | // Magic header for a binary file containing a Flatbuffer-serialized mobile |
79 | // Module. |
80 | return FileFormat::FlatbufferFileFormat; |
81 | } else if (std::memcmp(data, kZipMagicString.data(), kMagicSize) == 0) { |
82 | // Magic header for a zip file, which we use to store pickled sub-files. |
83 | return FileFormat::ZipFileFormat; |
84 | } |
85 | return FileFormat::UnknownFileFormat; |
86 | } |
87 | |
88 | /** |
89 | * Returns the likely file format based on the magic header bytes of @p data. |
90 | * If the stream position changes while inspecting the data, this function will |
91 | * restore the stream position to its original offset before returning. |
92 | */ |
93 | // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) |
94 | static inline FileFormat getFileFormat(std::istream& data) { |
95 | FileFormat format = FileFormat::UnknownFileFormat; |
96 | std::streampos orig_pos = data.tellg(); |
97 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) |
98 | std::array<char, kFileFormatHeaderSize> ; |
99 | data.read(header.data(), header.size()); |
100 | if (data.good()) { |
101 | format = getFileFormat(header.data()); |
102 | } |
103 | data.seekg(orig_pos, data.beg); |
104 | return format; |
105 | } |
106 | |
107 | /** |
108 | * Returns the likely file format based on the magic header bytes of the file |
109 | * named @p filename. |
110 | */ |
111 | // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) |
112 | static inline FileFormat getFileFormat(const std::string& filename) { |
113 | std::ifstream data(filename, std::ifstream::binary); |
114 | return getFileFormat(data); |
115 | } |
116 | |
117 | // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) |
118 | static void file_not_found_error() { |
119 | std::stringstream message; |
120 | message << "Error while opening file: " ; |
121 | if (errno == ENOENT) { |
122 | message << "no such file or directory" << std::endl; |
123 | } else { |
124 | message << "error no is: " << errno << std::endl; |
125 | } |
126 | TORCH_CHECK(false, message.str()); |
127 | } |
128 | |
129 | // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) |
130 | static inline std::tuple<std::shared_ptr<char>, size_t> get_file_content( |
131 | const char* filename) { |
132 | #if defined(HAVE_MMAP) |
133 | int fd = open(filename, O_RDONLY); |
134 | if (fd < 0) { |
135 | // failed to open file, chances are it's no such file or directory. |
136 | file_not_found_error(); |
137 | } |
138 | struct stat statbuf {}; |
139 | fstat(fd, &statbuf); |
140 | size_t size = statbuf.st_size; |
141 | void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0); |
142 | close(fd); |
143 | auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); }; |
144 | std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter); |
145 | #else |
146 | FILE* f = fopen(filename, "rb" ); |
147 | if (f == nullptr) { |
148 | file_not_found_error(); |
149 | } |
150 | fseek(f, 0, SEEK_END); |
151 | size_t size = ftell(f); |
152 | fseek(f, 0, SEEK_SET); |
153 | // make sure buffer size is multiple of alignment |
154 | size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment; |
155 | std::shared_ptr<char> data( |
156 | static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu); |
157 | fread(data.get(), size, 1, f); |
158 | fclose(f); |
159 | #endif |
160 | return std::make_tuple(data, size); |
161 | } |
162 | |
163 | // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) |
164 | static inline std::tuple<std::shared_ptr<char>, size_t> get_stream_content( |
165 | std::istream& in) { |
166 | // get size of the stream and reset to orig |
167 | std::streampos orig_pos = in.tellg(); |
168 | in.seekg(orig_pos, std::ios::end); |
169 | const long size = in.tellg(); |
170 | in.seekg(orig_pos, in.beg); |
171 | |
172 | // read stream |
173 | // NOLINT make sure buffer size is multiple of alignment |
174 | size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment; |
175 | std::shared_ptr<char> data( |
176 | static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu); |
177 | in.read(data.get(), size); |
178 | |
179 | // reset stream to original position |
180 | in.seekg(orig_pos, in.beg); |
181 | return std::make_tuple(data, size); |
182 | } |
183 | |
184 | // NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration) |
185 | static inline std::tuple<std::shared_ptr<char>, size_t> get_rai_content( |
186 | caffe2::serialize::ReadAdapterInterface* rai) { |
187 | size_t buffer_size = (rai->size() / kMaxAlignment + 1) * kMaxAlignment; |
188 | std::shared_ptr<char> data( |
189 | static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu); |
190 | rai->read( |
191 | 0, data.get(), rai->size(), "Loading ReadAdapterInterface to bytes" ); |
192 | return std::make_tuple(data, buffer_size); |
193 | } |
194 | |
195 | } // namespace jit |
196 | } // namespace torch |
197 | |