#include <torch/csrc/python_headers.h>
#include <system_error>

#include <ATen/ops/from_blob.h>
#include <c10/core/CPUAllocator.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/serialization.h>

template <class io>
Py_ssize_t doPartialRead(io fildes, void* buf, size_t nbytes);

template <class io>
Py_ssize_t doPartialWrite(io fildes, void* buf, size_t nbytes);

static Py_ssize_t doPartialPythonReadBuffered(
    PyObject* fildes,
    void* buf,
    size_t nbytes);
static Py_ssize_t doPartialPythonReadInto(
    PyObject* fildes,
    void* buf,
    size_t nbytes);
static Py_ssize_t doPartialPythonWrite(
    PyObject* fildes,
    void* buf,
    size_t nbytes);

template <>
Py_ssize_t doPartialRead<int>(int fildes, void* buf, size_t nbytes) {
  return read(fildes, buf, nbytes);
}

template <>
Py_ssize_t doPartialRead<PyObject*>(
    PyObject* fildes,
    void* buf,
    size_t nbytes) {
  // Try to use fildes.readinto() instead of fildes.read(), because it is
  // more memory efficient: it reads directly into our buffer instead of
  // allocating an intermediate bytes object.
  // TODO: Stop calling PyObject_HasAttrString() on every iteration of the
  // read loop.
  auto has_readinto = PyObject_HasAttrString(fildes, "readinto") == 1;
  if (has_readinto) {
    return doPartialPythonReadInto(fildes, buf, nbytes);
  }
  return doPartialPythonReadBuffered(fildes, buf, nbytes);
}
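
// Dispatch sketch for the specialization above (hypothetical caller; buf and
// fildes here are assumptions, not part of this file): with readinto() the
// bytes land directly in buf; otherwise a temporary bytes object is
// allocated and copied out of.
//
//   char buf[4096];
//   Py_ssize_t r = doPartialRead<PyObject*>(fildes, buf, sizeof(buf));
//   // r > 0: bytes read; r == 0: EOF. Python-side failures throw
//   // python_error rather than returning -1.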

template <>
Py_ssize_t doPartialWrite<int>(int fildes, void* buf, size_t nbytes) {
  return write(fildes, buf, nbytes);
}

template <>
Py_ssize_t doPartialWrite<PyObject*>(
    PyObject* fildes,
    void* buf,
    size_t nbytes) {
  return doPartialPythonWrite(fildes, buf, nbytes);
}

static inline bool isUnsupportedOperation() {
  THPObjectPtr io(PyImport_ImportModule("io"));
  if (!io)
    throw python_error();
  THPObjectPtr exception(PyObject_GetAttrString(io, "UnsupportedOperation"));
  if (!exception)
    throw python_error();
  return PyErr_ExceptionMatches(exception.get());
}

// Call Python fildes.read(nbytes) and copy the result into buf.
static inline Py_ssize_t doPartialPythonReadBuffered(
    PyObject* fildes,
    void* buf,
    size_t raw_nbytes) {
  // If we request a large amount of data, f.read() will internally try to
  // allocate a buffer of that size. This is counterproductive, because
  // it's not the buffer we ultimately want to write the data into. Read
  // less than that and avoid allocating too much extra memory.
  // TODO: Maybe 256 KiB is a bit small...
  const size_t nbytes = std::min<size_t>(raw_nbytes, 262144u); // 2^18 (256 KiB)

  THPObjectPtr r(
      PyObject_CallMethod(fildes, "read", "i", static_cast<int>(nbytes)));
  if (!r)
    throw python_error();

  auto size = PyBytes_GET_SIZE(r.get());
  const void* py_buf = PyBytes_AsString(r.get());

  // we read EOF
  if (size == 0) {
    return 0;
  }

  // Slurp it into the buffer we actually want
  memcpy(buf, py_buf, size);

  return size;
}
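
// Usage sketch (hypothetical buf/total): because each call caps f.read() at
// 262144 bytes, a caller wanting more must loop until satisfied, exactly as
// doRead() below does.
//
//   size_t got = 0;
//   while (got < total) {
//     Py_ssize_t r = doPartialPythonReadBuffered(
//         fildes, static_cast<char*>(buf) + got, total - got);
//     if (r == 0) break; // EOF
//     got += static_cast<size_t>(r);
//   }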

// Either does fildes.readinto(buf) or fildes.write(buf)
static inline Py_ssize_t doPartialPythonIO(
    PyObject* fildes,
    void* buf,
    size_t nbytes,
    bool is_read) {
  auto rw_flag = is_read ? PyBUF_WRITE : PyBUF_READ;
  THPObjectPtr memview(
      PyMemoryView_FromMemory(reinterpret_cast<char*>(buf), nbytes, rw_flag));
  if (!memview)
    throw python_error();

  std::string method = "write";
  if (is_read) {
    method = "readinto";
  }
  THPObjectPtr r(
      PyObject_CallMethod(fildes, method.c_str(), "O", memview.get()));
  if (r) {
    return PyLong_AsSsize_t(r.get());
  }

  // fildes.readinto can raise io.UnsupportedOperation, in which case we
  // fall back to fildes.read.
  if (is_read && isUnsupportedOperation()) {
    PyErr_Clear();
    return doPartialPythonReadBuffered(fildes, buf, nbytes);
  }
  throw python_error();
}
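
// The memoryview trick above, in isolation (a sketch using only the CPython
// C API; f and buf are hypothetical): wrapping a raw C buffer in a writable
// memoryview lets a Python-level readinto() fill it with no extra copies.
//
//   char buf[1024];
//   PyObject* view = PyMemoryView_FromMemory(buf, sizeof(buf), PyBUF_WRITE);
//   PyObject* n = PyObject_CallMethod(f, "readinto", "O", view);
//   // On success, n is a Python int: the number of bytes read into buf.
//   Py_XDECREF(n);
//   Py_XDECREF(view);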

// Call Python fildes.readinto(buf)
static Py_ssize_t doPartialPythonReadInto(
    PyObject* fildes,
    void* buf,
    size_t nbytes) {
  return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ true);
}

// Call Python fildes.write(buf)
static Py_ssize_t doPartialPythonWrite(
    PyObject* fildes,
    void* buf,
    size_t nbytes) {
  return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ false);
}

// Requires that we read EXACTLY nbytes; fails if we don't.
template <typename io>
void doRead(io fildes, void* raw_buf, size_t nbytes) {
  char* buf = static_cast<char*>(raw_buf);
  while (nbytes > 0) {
    errno = 0; // doPartialRead may not set errno
    // we read in 1GB blocks to avoid bugs on Mac OS X Lion
    // see https://github.com/pytorch/pytorch/issues/1031 for more details
    Py_ssize_t r =
        doPartialRead(fildes, buf, std::min<size_t>(nbytes, 1073741824));
    if (r < 0) {
      int err = errno;
      TORCH_INTERNAL_ASSERT(
          err != 0, "read(): impossible! r < 0, but no errno was set");
      TORCH_INTERNAL_ASSERT(
          err != EAGAIN,
          "read(): non-blocking fd ",
          fildes,
          " read EAGAIN; cowardly refusing to spin-wait");
      if (err == EINTR) {
        continue;
      } else {
        AT_ERROR("read(): fd ", fildes, " failed with ", strerror(err));
      }
    } else if (r == 0) {
      break;
    }
    buf += r;
    // This is guaranteed by POSIX, but double-check so that the unsigned
    // nbytes subtraction below can never wrap around.
    AT_ASSERT(static_cast<size_t>(r) <= nbytes);
    nbytes -= r;
  }
  if (nbytes != 0) {
    AT_ERROR(
        "unexpected EOF, expected ",
        nbytes,
        " more bytes. The file might be corrupted.");
  }
}
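
// Usage sketch (hypothetical fd and value): doRead() either fills the buffer
// completely or throws; short reads and EINTR are retried internally.
//
//   int64_t header = 0;
//   doRead(fd, &header, sizeof(header)); // throws on premature EOF or error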

// Writes EXACTLY nbytes, retrying on EINTR and partial writes; throws on
// error.
template <typename io>
void doWrite(io fildes, void* raw_buf, size_t nbytes) {
  char* buf = static_cast<char*>(raw_buf);
  while (nbytes > 0) {
    errno = 0; // doPartialWrite may not set errno
    // we write in 1GB blocks to avoid bugs on Mac OS X Lion
    // see https://github.com/pytorch/pytorch/issues/1031 for more details
    Py_ssize_t r =
        doPartialWrite(fildes, buf, std::min<size_t>(nbytes, 1073741824));
    if (r < 0) {
      int err = errno;
      TORCH_INTERNAL_ASSERT(
          err != 0, "write(): impossible! r < 0, but no errno was set");
      TORCH_INTERNAL_ASSERT(
          err != EAGAIN,
          "write(): non-blocking fd ",
          fildes,
          " got EAGAIN; cowardly refusing to spin-wait");
      if (err == EINTR) {
        continue;
      } else {
        AT_ERROR("write(): fd ", fildes, " failed with ", strerror(err));
      }
    }
    buf += r;
    AT_ASSERT(static_cast<size_t>(r) <= nbytes);
    nbytes -= r;
  }
}

// save_size is necessary since the old eager format saved storages as
// [size + data], but the v1.5 eager format removes this since the size is
// recoverable from the file size.
template <class io>
void THPStorage_writeFileRaw(
    c10::StorageImpl* self,
    io fd,
    bool save_size,
    uint64_t element_size) {
  c10::DeviceGuard guard(self->device());
  uint8_t* data{};
  at::Tensor cpu_tensor;
  int64_t size_bytes = self->nbytes();
  int64_t numel = size_bytes / element_size;
  if (self->device_type() == at::kCPU) {
    data = self->data<uint8_t>();
  } else {
    // Use tensor.to() to implement the device-to-host copy for all
    // non-CPU devices.
    auto device_tensor = at::from_blob(
        self->data<void>(),
        {size_bytes},
        {1},
        nullptr,
        at::device(self->device()).dtype(c10::kByte),
        {self->device()});
    cpu_tensor = device_tensor.to(at::kCPU);
    data = (uint8_t*)cpu_tensor.data_ptr();
  }
  if (save_size) {
    if (torch::utils::THP_nativeByteOrder() ==
        torch::utils::THPByteOrder::THP_LITTLE_ENDIAN)
      doWrite(fd, &numel, sizeof(int64_t));
    else {
      int64_t nsize{}; // convert big endian cpu to little endian storage
      torch::utils::THP_encodeInt64Buffer(
          (uint8_t*)&nsize,
          (const int64_t*)&numel,
          torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
          1);
      doWrite(fd, &nsize, sizeof(int64_t));
    }
  }
  // fast track for bytes and little endian
  if (element_size == 1 ||
      torch::utils::THP_nativeByteOrder() ==
          torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) {
    doWrite(fd, data, size_bytes);
  } else {
    int64_t buffer_size = std::min(numel, (int64_t)5000);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    std::unique_ptr<uint8_t[]> le_buffer(
        new uint8_t[buffer_size * element_size]);
    for (int64_t i = 0; i < numel; i += buffer_size) {
      size_t to_convert = std::min(numel - i, buffer_size);
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (element_size == 2) {
        torch::utils::THP_encodeInt16Buffer(
            (uint8_t*)le_buffer.get(),
            (const int16_t*)data + i,
            torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
            to_convert);
      } else if (element_size == 4) {
        torch::utils::THP_encodeInt32Buffer(
            (uint8_t*)le_buffer.get(),
            (const int32_t*)data + i,
            torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
            to_convert);
      } else if (element_size == 8) {
        torch::utils::THP_encodeInt64Buffer(
            (uint8_t*)le_buffer.get(),
            (const int64_t*)data + i,
            torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
            to_convert);
      }
      doWrite(fd, le_buffer.get(), to_convert * element_size);
    }
  }
}
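
// On-disk layout produced above (a sketch; all multi-byte values are
// little-endian):
//
//   save_size == true:  [ numel : int64 ][ numel * element_size bytes ]
//   save_size == false: [ numel * element_size bytes ]
//
// Hypothetical call for a CPU storage of float32 elements, writing to a raw
// file descriptor:
//
//   THPStorage_writeFileRaw<int>(
//       storage_impl, fd, /*save_size=*/true, /*element_size=*/4);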

template void THPStorage_writeFileRaw<int>(
    c10::StorageImpl* self,
    int fd,
    bool save_size,
    uint64_t element_size);
template void THPStorage_writeFileRaw<PyObject*>(
    c10::StorageImpl* self,
    PyObject* fd,
    bool save_size,
    uint64_t element_size);

template <class io>
c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
    io file,
    c10::intrusive_ptr<c10::StorageImpl> storage,
    uint64_t element_size) {
  c10::OptionalDeviceGuard guard;
  if (storage.defined()) {
    guard.reset_device(storage->device());
  }
  int64_t size{};
  doRead(file, &size, sizeof(int64_t));
  if (torch::utils::THP_nativeByteOrder() ==
      torch::utils::THPByteOrder::THP_BIG_ENDIAN) {
    int64_t tsize = size; // convert little endian storage to big endian cpu
    torch::utils::THP_decodeInt64Buffer(
        &size, (const uint8_t*)&tsize, torch::utils::THP_nativeByteOrder(), 1);
  }
  int64_t nbytes = element_size * size;
  if (!storage.defined()) {
    storage = c10::make_intrusive<at::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(),
        nbytes,
        c10::GetDefaultCPUAllocator(),
        /*resizable=*/true);
  } else {
    int64_t _storage_nbytes = storage->nbytes();
    TORCH_CHECK(
        _storage_nbytes == nbytes,
        "storage has wrong byte size: expected ",
        nbytes,
        " got ",
        _storage_nbytes);
  }

  std::unique_ptr<char[]> cpu_data;

  uint8_t* data{};
  if (storage->device_type() == at::kCPU) {
    data = storage->data<uint8_t>();
  } else {
    cpu_data = std::unique_ptr<char[]>(new char[nbytes]);
    data = (uint8_t*)cpu_data.get();
  }

  // fast track for bytes and little endian
  if (element_size == 1 ||
      torch::utils::THP_nativeByteOrder() ==
          torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) {
    doRead(file, data, storage->nbytes());
  } else {
    int64_t buffer_size = std::min(size, (int64_t)5000);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    std::unique_ptr<uint8_t[]> le_buffer(
        new uint8_t[buffer_size * element_size]);

    for (int64_t i = 0; i < size; i += buffer_size) {
      size_t to_convert = std::min(size - i, buffer_size);
      doRead(file, le_buffer.get(), element_size * to_convert);

      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (element_size == 2) {
        torch::utils::THP_decodeInt16Buffer(
            (int16_t*)data + i,
            le_buffer.get(),
            torch::utils::THP_nativeByteOrder(),
            to_convert);
      } else if (element_size == 4) {
        torch::utils::THP_decodeInt32Buffer(
            (int32_t*)data + i,
            le_buffer.get(),
            torch::utils::THP_nativeByteOrder(),
            to_convert);
      } else if (element_size == 8) {
        torch::utils::THP_decodeInt64Buffer(
            (int64_t*)data + i,
            le_buffer.get(),
            torch::utils::THP_nativeByteOrder(),
            to_convert);
      }
    }
  }

  if (storage->device_type() != at::kCPU) {
    // Use tensor.copy_() to implement the host-to-device copy for all
    // non-CPU devices.
    auto cpu_tensor = at::from_blob(
        (void*)data, {nbytes}, at::device(at::kCPU).dtype(c10::kByte));
    auto device_tensor = at::from_blob(
        storage->data<void>(),
        {nbytes},
        {1},
        nullptr,
        at::device(storage->device()).dtype(c10::kByte),
        {storage->device()});
    device_tensor.copy_(cpu_tensor);
  }
  return storage;
}
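
// Usage sketch (hypothetical fd): passing an empty storage pointer allocates
// a fresh CPU storage sized from the int64 length prefix; passing an
// existing storage requires its byte size to match exactly.
//
//   auto storage = THPStorage_readFileRaw<int>(
//       fd, /*storage=*/nullptr, /*element_size=*/4);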

template c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw<int>(
    int fd,
    c10::intrusive_ptr<c10::StorageImpl> storage,
    uint64_t element_size);
template c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw<PyObject*>(
    PyObject* fd,
    c10::intrusive_ptr<c10::StorageImpl> storage,
    uint64_t element_size);