1 | #include <torch/csrc/python_headers.h> |
2 | #include <system_error> |
3 | |
4 | #include <ATen/ops/from_blob.h> |
5 | #include <c10/core/CPUAllocator.h> |
6 | #include <torch/csrc/THP.h> |
7 | #include <torch/csrc/serialization.h> |
8 | |
// Partial-transfer primitives: a single call may move FEWER than nbytes
// bytes (like read(2)/write(2)); doRead/doWrite below loop until complete.
// Specialized for POSIX file descriptors (int) and Python file-like
// objects (PyObject*).
template <class io>
Py_ssize_t doPartialRead(io fildes, void* buf, size_t nbytes);

template <class io>
Py_ssize_t doPartialWrite(io fildes, void* buf, size_t nbytes);

// Helpers for the PyObject* specializations, defined later in this file.
static Py_ssize_t doPartialPythonReadBuffered(
    PyObject* fildes,
    void* buf,
    size_t nbytes);
static Py_ssize_t doPartialPythonReadInto(
    PyObject* fildes,
    void* buf,
    size_t nbytes);
static Py_ssize_t doPartialPythonWrite(
    PyObject* fildes,
    void* buf,
    size_t nbytes);
27 | |
28 | template <> |
29 | Py_ssize_t doPartialRead<int>(int fildes, void* buf, size_t nbytes) { |
30 | return read(fildes, buf, nbytes); |
31 | } |
32 | |
33 | template <> |
34 | Py_ssize_t doPartialRead<PyObject*>( |
35 | PyObject* fildes, |
36 | void* buf, |
37 | size_t nbytes) { |
38 | // Try to use fildes.readinto() instead of fildes.read() |
39 | // because it is more memory efficient. |
40 | // TODO: Stop calling PyObject_HasAttrString() in a loop on our read loop |
41 | auto has_readinto = PyObject_HasAttrString(fildes, "readinto" ) == 1; |
42 | if (has_readinto) { |
43 | return doPartialPythonReadInto(fildes, buf, nbytes); |
44 | } |
45 | return doPartialPythonReadBuffered(fildes, buf, nbytes); |
46 | } |
47 | |
48 | template <> |
49 | Py_ssize_t doPartialWrite<int>(int fildes, void* buf, size_t nbytes) { |
50 | return write(fildes, buf, nbytes); |
51 | } |
52 | |
53 | template <> |
54 | Py_ssize_t doPartialWrite<PyObject*>( |
55 | PyObject* fildes, |
56 | void* buf, |
57 | size_t nbytes) { |
58 | return doPartialPythonWrite(fildes, buf, nbytes); |
59 | } |
60 | |
61 | static inline bool isUnsupportedOperation() { |
62 | THPObjectPtr io(PyImport_ImportModule("io" )); |
63 | if (!io) |
64 | throw python_error(); |
65 | THPObjectPtr exception(PyObject_GetAttrString(io, "UnsupportedOperation" )); |
66 | if (!exception) |
67 | throw python_error(); |
68 | return PyErr_ExceptionMatches(exception.get()); |
69 | } |
70 | |
71 | // Call Python fildes.read(nbytes) and copy it to buf. |
72 | static inline Py_ssize_t doPartialPythonReadBuffered( |
73 | PyObject* fildes, |
74 | void* buf, |
75 | size_t raw_nbytes) { |
76 | // If we request a large amount of data, f.read() will internally try to |
77 | // allocate a buffer of that size. This is counterproductive, because |
78 | // it's not the buffer we ultimately want to write the data into. Read |
79 | // less than that and avoid allocating too much extra memory. |
80 | // TODO: Maybe 260 KB is a bit small... |
81 | const size_t nbytes = std::min<size_t>(raw_nbytes, 262144u); // 2^18 (~260 KB) |
82 | |
83 | THPObjectPtr r(PyObject_CallMethod(fildes, "read" , "i" , nbytes)); |
84 | if (!r) |
85 | throw python_error(); |
86 | |
87 | auto size = PyBytes_GET_SIZE(r.get()); |
88 | const void* py_buf = PyBytes_AsString(r.get()); |
89 | |
90 | // we read EOF |
91 | if (size == 0) { |
92 | return 0; |
93 | } |
94 | |
95 | // Slurp it into the buffer we actually want |
96 | memcpy(buf, py_buf, size); |
97 | |
98 | return size; |
99 | } |
100 | |
101 | // Either does fildes.readinto(buf) or fildes.write(buf) |
102 | static inline Py_ssize_t doPartialPythonIO( |
103 | PyObject* fildes, |
104 | void* buf, |
105 | size_t nbytes, |
106 | bool is_read) { |
107 | auto rw_flag = is_read ? PyBUF_WRITE : PyBUF_READ; |
108 | THPObjectPtr memview( |
109 | PyMemoryView_FromMemory(reinterpret_cast<char*>(buf), nbytes, rw_flag)); |
110 | if (!memview) |
111 | throw python_error(); |
112 | |
113 | std::string method = "write" ; |
114 | if (is_read) { |
115 | method = "readinto" ; |
116 | } |
117 | THPObjectPtr r( |
118 | PyObject_CallMethod(fildes, method.c_str(), "O" , memview.get())); |
119 | if (r) { |
120 | return PyLong_AsSsize_t(r.get()); |
121 | } |
122 | |
123 | // fildes.readinto can return UnsupportedOperation so fall back to |
124 | // fildes.read. |
125 | if (is_read && isUnsupportedOperation()) { |
126 | PyErr_Clear(); |
127 | return doPartialPythonReadBuffered(fildes, buf, nbytes); |
128 | } |
129 | throw python_error(); |
130 | } |
131 | |
132 | // Call Python fildes.readinto(buf) |
133 | static Py_ssize_t doPartialPythonReadInto( |
134 | PyObject* fildes, |
135 | void* buf, |
136 | size_t nbytes) { |
137 | return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ true); |
138 | } |
139 | |
140 | // Call Python fildes.write(buf) |
141 | static Py_ssize_t doPartialPythonWrite( |
142 | PyObject* fildes, |
143 | void* buf, |
144 | size_t nbytes) { |
145 | return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ false); |
146 | } |
147 | |
148 | // Requires that we read EXACTLY nbytes; fails if we don't. |
149 | template <typename io> |
150 | void doRead(io fildes, void* raw_buf, size_t nbytes) { |
151 | char* buf = static_cast<char*>(raw_buf); |
152 | while (nbytes > 0) { |
153 | errno = 0; // doPartialRead may not set errno |
154 | // we read in 1GB blocks to avoid bugs on Mac OS X Lion |
155 | // see https://github.com/pytorch/pytorch/issues/1031 for more details |
156 | Py_ssize_t r = |
157 | doPartialRead(fildes, buf, std::min<size_t>(nbytes, 1073741824)); |
158 | if (r < 0) { |
159 | int err = errno; |
160 | TORCH_INTERNAL_ASSERT( |
161 | err != 0, "read(): impossible! r < 0, but no errno was set" ); |
162 | TORCH_INTERNAL_ASSERT( |
163 | err != EAGAIN, |
164 | "read(): non-blocking fd " , |
165 | fildes, |
166 | " read EAGAIN; cowardly refusing to spin-wait" ); |
167 | if (err == EINTR) { |
168 | continue; |
169 | } else { |
170 | AT_ERROR("read(): fd " , fildes, " failed with " , strerror(err)); |
171 | } |
172 | } else if (r == 0) { |
173 | break; |
174 | } |
175 | buf += r; |
176 | // This is guaranteed by POSIX, but I just want to be double-sure |
177 | // to not underflow a signed integer. |
178 | AT_ASSERT(static_cast<size_t>(r) <= nbytes); |
179 | nbytes -= r; |
180 | } |
181 | if (nbytes != 0) { |
182 | AT_ERROR( |
183 | "unexpected EOF, expected " , |
184 | nbytes, |
185 | " more bytes. The file might be corrupted." ); |
186 | } |
187 | } |
188 | |
189 | template <typename io> |
190 | void doWrite(io fildes, void* raw_buf, size_t nbytes) { |
191 | char* buf = static_cast<char*>(raw_buf); |
192 | while (nbytes > 0) { |
193 | errno = 0; // doPartialWrite may not set errno |
194 | // we write in 1GB blocks to avoid bugs on Mac OS X Lion |
195 | // see https://github.com/pytorch/pytorch/issues/1031 for more details |
196 | Py_ssize_t r = |
197 | doPartialWrite(fildes, buf, std::min<size_t>(nbytes, 1073741824)); |
198 | if (r < 0) { |
199 | int err = errno; |
200 | TORCH_INTERNAL_ASSERT( |
201 | err != 0, "write(): impossible! r < 0, but no errno was set" ); |
202 | TORCH_INTERNAL_ASSERT( |
203 | err != EAGAIN, |
204 | "write(): non-blocking fd " , |
205 | fildes, |
206 | " read EAGAIN; cowardly refusing to spin-wait" ); |
207 | if (err == EINTR) { |
208 | continue; |
209 | } else { |
210 | AT_ERROR("write(): fd " , fildes, " failed with " , strerror(err)); |
211 | } |
212 | } |
213 | buf += r; |
214 | AT_ASSERT(static_cast<size_t>(r) <= nbytes); |
215 | nbytes -= r; |
216 | } |
217 | } |
218 | |
// save_size is necessary since the old eager format saved storages as
// [size + data], but the v1.5 eager format removes this since size is saved in
// the filesize.
// Serializes a storage to `fd` (POSIX fd or Python file-like object).
// On-disk layout: optional int64 element count (little endian, written only
// when save_size is true), followed by the raw element data stored little
// endian. Non-CPU storages are staged through a CPU tensor copy first.
// NOTE(review): the byte-swap path assumes element_size is 1, 2, 4 or 8;
// other sizes would silently write an unconverted chunk — confirm callers
// only pass those sizes.
template <class io>
void THPStorage_writeFileRaw(
    c10::StorageImpl* self,
    io fd,
    bool save_size,
    uint64_t element_size) {
  // Make the storage's device current while we touch its data.
  c10::DeviceGuard guard(self->device());
  uint8_t* data{};
  at::Tensor cpu_tensor; // keeps the D2H copy alive while we write from it
  int64_t size_bytes = self->nbytes();
  int64_t numel = size_bytes / element_size;
  if (self->device_type() == at::kCPU) {
    // CPU storage: write straight out of the storage's own buffer.
    data = self->data<uint8_t>();
  } else {
    // Here we use a tensor.to() to impl D2H for all non-CPU device.
    auto device_tensor = at::from_blob(
        self->data<void>(),
        {size_bytes},
        {1},
        NULL,
        at::device(self->device()).dtype(c10::kByte),
        {self->device()});
    cpu_tensor = device_tensor.to(at::kCPU);
    data = (uint8_t*)cpu_tensor.data_ptr();
  }
  if (save_size) {
    // The element count is always stored little endian on disk.
    if (torch::utils::THP_nativeByteOrder() ==
        torch::utils::THPByteOrder::THP_LITTLE_ENDIAN)
      doWrite(fd, &numel, sizeof(int64_t));
    else {
      int64_t nsize{}; // convert big endian cpu to little endian storage
      torch::utils::THP_encodeInt64Buffer(
          (uint8_t*)&nsize,
          (const int64_t*)&numel,
          torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
          1);
      doWrite(fd, &nsize, sizeof(int64_t));
    }
  }
  // fast track for bytes and little endian
  if (element_size == 1 ||
      torch::utils::THP_nativeByteOrder() ==
          torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) {
    doWrite(fd, data, size_bytes);
  } else {
    // Big-endian host with multi-byte elements: byte-swap through a bounded
    // scratch buffer (at most 5000 elements) and write one chunk at a time.
    int64_t buffer_size = std::min(numel, (int64_t)5000);
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    std::unique_ptr<uint8_t[]> le_buffer(
        new uint8_t[buffer_size * element_size]);
    for (int64_t i = 0; i < numel; i += buffer_size) {
      size_t to_convert = std::min(numel - i, buffer_size);
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (element_size == 2) {
        torch::utils::THP_encodeInt16Buffer(
            (uint8_t*)le_buffer.get(),
            (const int16_t*)data + i, // i is an element index, not bytes
            torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
            to_convert);
      } else if (element_size == 4) {
        torch::utils::THP_encodeInt32Buffer(
            (uint8_t*)le_buffer.get(),
            (const int32_t*)data + i,
            torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
            to_convert);
      } else if (element_size == 8) {
        torch::utils::THP_encodeInt64Buffer(
            (uint8_t*)le_buffer.get(),
            (const int64_t*)data + i,
            torch::utils::THPByteOrder::THP_LITTLE_ENDIAN,
            to_convert);
      }
      doWrite(fd, le_buffer.get(), to_convert * element_size);
    }
  }
}
297 | |
// Explicit instantiations for the two supported sinks: a POSIX file
// descriptor and a Python file-like object.
template void THPStorage_writeFileRaw<int>(
    c10::StorageImpl* self,
    int fd,
    bool save_size,
    uint64_t element_size);
template void THPStorage_writeFileRaw<PyObject*>(
    c10::StorageImpl* self,
    PyObject* fd,
    bool save_size,
    uint64_t element_size);
308 | |
309 | template <class io> |
310 | c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw( |
311 | io file, |
312 | c10::intrusive_ptr<c10::StorageImpl> storage, |
313 | uint64_t element_size) { |
314 | c10::OptionalDeviceGuard guard; |
315 | if (storage.defined()) { |
316 | guard.reset_device(storage->device()); |
317 | } |
318 | int64_t size{}; |
319 | doRead(file, &size, sizeof(int64_t)); |
320 | if (torch::utils::THP_nativeByteOrder() == |
321 | torch::utils::THPByteOrder::THP_BIG_ENDIAN) { |
322 | int64_t tsize = size; // convert little endian storage to big endian cpu |
323 | torch::utils::THP_decodeInt64Buffer( |
324 | &size, (const uint8_t*)&tsize, torch::utils::THP_nativeByteOrder(), 1); |
325 | } |
326 | int64_t nbytes = element_size * size; |
327 | if (!storage.defined()) { |
328 | storage = c10::make_intrusive<at::StorageImpl>( |
329 | c10::StorageImpl::use_byte_size_t(), |
330 | nbytes, |
331 | c10::GetDefaultCPUAllocator(), |
332 | /*resizable=*/true); |
333 | } else { |
334 | int64_t _storage_nbytes = storage->nbytes(); |
335 | TORCH_CHECK( |
336 | _storage_nbytes == nbytes, |
337 | "storage has wrong byte size: expected %ld got %ld" , |
338 | nbytes, |
339 | _storage_nbytes); |
340 | } |
341 | |
342 | std::unique_ptr<char[]> cpu_data; |
343 | |
344 | uint8_t* data{}; |
345 | if (storage->device_type() == at::kCPU) { |
346 | data = storage->data<uint8_t>(); |
347 | } else { |
348 | cpu_data = std::unique_ptr<char[]>(new char[nbytes]); |
349 | data = (uint8_t*)cpu_data.get(); |
350 | } |
351 | |
352 | // fast track for bytes and little endian |
353 | if (element_size == 1 || |
354 | torch::utils::THP_nativeByteOrder() == |
355 | torch::utils::THPByteOrder::THP_LITTLE_ENDIAN) { |
356 | doRead(file, data, storage->nbytes()); |
357 | } else { |
358 | int64_t buffer_size = std::min(size, (int64_t)5000); |
359 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) |
360 | std::unique_ptr<uint8_t[]> le_buffer( |
361 | new uint8_t[buffer_size * element_size]); |
362 | |
363 | for (int64_t i = 0; i < size; i += buffer_size) { |
364 | size_t to_convert = std::min(size - i, buffer_size); |
365 | doRead(file, le_buffer.get(), element_size * to_convert); |
366 | |
367 | // NOLINTNEXTLINE(bugprone-branch-clone) |
368 | if (element_size == 2) { |
369 | torch::utils::THP_decodeInt16Buffer( |
370 | (int16_t*)data + i, |
371 | le_buffer.get(), |
372 | torch::utils::THP_nativeByteOrder(), |
373 | to_convert); |
374 | } else if (element_size == 4) { |
375 | torch::utils::THP_decodeInt32Buffer( |
376 | (int32_t*)data + i, |
377 | le_buffer.get(), |
378 | torch::utils::THP_nativeByteOrder(), |
379 | to_convert); |
380 | } else if (element_size == 8) { |
381 | torch::utils::THP_decodeInt64Buffer( |
382 | (int64_t*)data + i, |
383 | le_buffer.get(), |
384 | torch::utils::THP_nativeByteOrder(), |
385 | to_convert); |
386 | } |
387 | } |
388 | } |
389 | |
390 | if (storage->device_type() != at::kCPU) { |
391 | // Here we use a tensor.copy_() to impl H2D for all non-CPU device. |
392 | auto cpu_tensor = at::from_blob( |
393 | (void*)data, {nbytes}, at::device(at::kCPU).dtype(c10::kByte)); |
394 | auto device_tensor = at::from_blob( |
395 | storage->data<void>(), |
396 | {nbytes}, |
397 | {1}, |
398 | NULL, |
399 | at::device(storage->device()).dtype(c10::kByte), |
400 | {storage->device()}); |
401 | device_tensor.copy_(cpu_tensor); |
402 | } |
403 | return storage; |
404 | } |
405 | |
// Explicit instantiations for the two supported sources: a POSIX file
// descriptor and a Python file-like object.
template c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw<int>(
    int fd,
    c10::intrusive_ptr<c10::StorageImpl> storage,
    uint64_t element_size);
template c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw<PyObject*>(
    PyObject* fd,
    c10::intrusive_ptr<c10::StorageImpl> storage,
    uint64_t element_size);
414 | |