1#pragma once
2
#include <array>
#include <cerrno>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <ios>
#include <istream>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>

#include <c10/core/CPUAllocator.h>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/util/Exception.h>
#include <caffe2/serialize/read_adapter_interface.h>

#if defined(HAVE_MMAP)
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#endif
22
23/**
24 * @file
25 *
 * Helpers for identifying file formats and for loading the contents of files,
 * streams, and ReadAdapterInterface objects when reading serialized data.
27 *
28 * Note that these functions are declared inline because they will typically
29 * only be called from one or two locations per binary.
30 */
31
32namespace torch {
33namespace jit {
34
/**
 * The serialization formats that getFileFormat() can tell apart by looking at
 * the leading bytes of a file or data stream.
 */
enum class FileFormat {
  /// Neither the Flatbuffer nor the ZIP magic bytes were recognized.
  UnknownFileFormat = 0,
  /// A Flatbuffer-serialized mobile module ("PTMF" at bytes 4..7).
  FlatbufferFileFormat = 1,
  /// A ZIP archive ("PK\x03\x04" at bytes 0..3).
  ZipFileFormat = 2,
};
43
/// Number of leading bytes that #getFileFormat() inspects, i.e. the minimum
/// size of the buffer passed to it. The Flatbuffer magic occupies bytes 4..7
/// and the ZIP magic bytes 0..3, so 8 bytes cover both checks.
constexpr size_t kFileFormatHeaderSize{8};
/// Granularity, in bytes, to which the get_*_content() helpers round up the
/// size of the heap buffers they allocate.
constexpr size_t kMaxAlignment{16};
47
48/**
49 * Returns the likely file format based on the magic header bytes in @p header,
50 * which should contain the first bytes of a file or data stream.
51 */
52// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
53static inline FileFormat getFileFormat(const char* data) {
54 // The size of magic strings to look for in the buffer.
55 static constexpr size_t kMagicSize = 4;
56
57 // Bytes 4..7 of a Flatbuffer-encoded file produced by
58 // `flatbuffer_serializer.h`. (The first four bytes contain an offset to the
59 // actual Flatbuffer data.)
60 static constexpr std::array<char, kMagicSize> kFlatbufferMagicString = {
61 'P', 'T', 'M', 'F'};
62 static constexpr size_t kFlatbufferMagicOffset = 4;
63
64 // The first four bytes of a ZIP file.
65 static constexpr std::array<char, kMagicSize> kZipMagicString = {
66 'P', 'K', '\x03', '\x04'};
67
68 // Note that we check for Flatbuffer magic first. Since the first four bytes
69 // of flatbuffer data contain an offset to the root struct, it's theoretically
70 // possible to construct a file whose offset looks like the ZIP magic. On the
71 // other hand, bytes 4-7 of ZIP files are constrained to a small set of values
72 // that do not typically cross into the printable ASCII range, so a ZIP file
73 // should never have a header that looks like a Flatbuffer file.
74 if (std::memcmp(
75 data + kFlatbufferMagicOffset,
76 kFlatbufferMagicString.data(),
77 kMagicSize) == 0) {
78 // Magic header for a binary file containing a Flatbuffer-serialized mobile
79 // Module.
80 return FileFormat::FlatbufferFileFormat;
81 } else if (std::memcmp(data, kZipMagicString.data(), kMagicSize) == 0) {
82 // Magic header for a zip file, which we use to store pickled sub-files.
83 return FileFormat::ZipFileFormat;
84 }
85 return FileFormat::UnknownFileFormat;
86}
87
88/**
89 * Returns the likely file format based on the magic header bytes of @p data.
90 * If the stream position changes while inspecting the data, this function will
91 * restore the stream position to its original offset before returning.
92 */
93// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
94static inline FileFormat getFileFormat(std::istream& data) {
95 FileFormat format = FileFormat::UnknownFileFormat;
96 std::streampos orig_pos = data.tellg();
97 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
98 std::array<char, kFileFormatHeaderSize> header;
99 data.read(header.data(), header.size());
100 if (data.good()) {
101 format = getFileFormat(header.data());
102 }
103 data.seekg(orig_pos, data.beg);
104 return format;
105}
106
107/**
108 * Returns the likely file format based on the magic header bytes of the file
109 * named @p filename.
110 */
111// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
112static inline FileFormat getFileFormat(const std::string& filename) {
113 std::ifstream data(filename, std::ifstream::binary);
114 return getFileFormat(data);
115}
116
117// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
118static void file_not_found_error() {
119 std::stringstream message;
120 message << "Error while opening file: ";
121 if (errno == ENOENT) {
122 message << "no such file or directory" << std::endl;
123 } else {
124 message << "error no is: " << errno << std::endl;
125 }
126 TORCH_CHECK(false, message.str());
127}
128
129// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
130static inline std::tuple<std::shared_ptr<char>, size_t> get_file_content(
131 const char* filename) {
132#if defined(HAVE_MMAP)
133 int fd = open(filename, O_RDONLY);
134 if (fd < 0) {
135 // failed to open file, chances are it's no such file or directory.
136 file_not_found_error();
137 }
138 struct stat statbuf {};
139 fstat(fd, &statbuf);
140 size_t size = statbuf.st_size;
141 void* ptr = mmap(nullptr, statbuf.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
142 close(fd);
143 auto deleter = [statbuf](char* ptr) { munmap(ptr, statbuf.st_size); };
144 std::shared_ptr<char> data(reinterpret_cast<char*>(ptr), deleter);
145#else
146 FILE* f = fopen(filename, "rb");
147 if (f == nullptr) {
148 file_not_found_error();
149 }
150 fseek(f, 0, SEEK_END);
151 size_t size = ftell(f);
152 fseek(f, 0, SEEK_SET);
153 // make sure buffer size is multiple of alignment
154 size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
155 std::shared_ptr<char> data(
156 static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
157 fread(data.get(), size, 1, f);
158 fclose(f);
159#endif
160 return std::make_tuple(data, size);
161}
162
163// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
164static inline std::tuple<std::shared_ptr<char>, size_t> get_stream_content(
165 std::istream& in) {
166 // get size of the stream and reset to orig
167 std::streampos orig_pos = in.tellg();
168 in.seekg(orig_pos, std::ios::end);
169 const long size = in.tellg();
170 in.seekg(orig_pos, in.beg);
171
172 // read stream
173 // NOLINT make sure buffer size is multiple of alignment
174 size_t buffer_size = (size / kMaxAlignment + 1) * kMaxAlignment;
175 std::shared_ptr<char> data(
176 static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
177 in.read(data.get(), size);
178
179 // reset stream to original position
180 in.seekg(orig_pos, in.beg);
181 return std::make_tuple(data, size);
182}
183
184// NOLINTNEXTLINE(facebook-hte-NamespaceScopedStaticDeclaration)
185static inline std::tuple<std::shared_ptr<char>, size_t> get_rai_content(
186 caffe2::serialize::ReadAdapterInterface* rai) {
187 size_t buffer_size = (rai->size() / kMaxAlignment + 1) * kMaxAlignment;
188 std::shared_ptr<char> data(
189 static_cast<char*>(c10::alloc_cpu(buffer_size)), c10::free_cpu);
190 rai->read(
191 0, data.get(), rai->size(), "Loading ReadAdapterInterface to bytes");
192 return std::make_tuple(data, buffer_size);
193}
194
195} // namespace jit
196} // namespace torch
197