1#include <torch/csrc/python_headers.h>
2#ifdef _MSC_VER
3#include <c10/util/win32-headers.h>
4#endif
5#include <structmember.h>
6
7#include <c10/core/CPUAllocator.h>
8#include <libshm.h>
9#include <torch/csrc/CudaIPCTypes.h>
10#include <torch/csrc/Device.h>
11#include <torch/csrc/DynamicTypes.h>
12#include <torch/csrc/THP.h>
13#include <torch/csrc/autograd/utils/wrap_outputs.h>
14#include <torch/csrc/copy_utils.h>
15
16#include <c10/util/intrusive_ptr.h>
17#include <fmt/format.h>
18
19#include <torch/csrc/Storage.h>
20#include <torch/csrc/StorageMethods.h>
21
22#include <ATen/ATen.h>
23#include <ATen/MapAllocator.h>
24#include <torch/csrc/utils/pycfunction_helpers.h>
25#include <torch/csrc/utils/python_arg_parser.h>
26#include <torch/csrc/utils/python_numbers.h>
27
28#ifdef USE_CUDA
29#include <ATen/native/cuda/Resize.h>
30#include <cuda_runtime.h>
31#endif
32
33#include <ATen/native/Resize.h>
34
35#ifdef _MSC_VER
36#define LSEEK _lseeki64
37#else
38#define LSEEK lseek
39#endif
40
41static PyObject* THPStorage_nbytes(PyObject* _self, PyObject* noargs) {
42 HANDLE_TH_ERRORS
43 auto self = (THPStorage*)_self;
44 return py::cast(self->cdata->sym_nbytes()).release().ptr();
45 END_HANDLE_TH_ERRORS
46}
47
48static PyObject* THPStorage_dataPtr(PyObject* _self, PyObject* noargs) {
49 HANDLE_TH_ERRORS
50 auto self = (THPStorage*)_self;
51 return PyLong_FromVoidPtr(self->cdata->data<uint8_t>());
52 END_HANDLE_TH_ERRORS
53}
54
55static PyObject* THPStorage_copy_(
56 PyObject* self,
57 PyObject* args,
58 PyObject* kwargs) {
59 HANDLE_TH_ERRORS
60
61 at::Storage self_ = torch::createStorage(self);
62
63 static torch::PythonArgParser parser({
64 "copy_(Storage src, bool? non_blocking=None)",
65 });
66 torch::ParsedArgs<2> parsed_args;
67 auto r = parser.parse(args, kwargs, parsed_args);
68
69 at::Storage src = r.storage(0);
70 bool non_blocking = r.toBoolOptional(1).value_or(false);
71
72 TORCH_CHECK(self_.nbytes() == src.nbytes(), "size does not match");
73
74 storage_copy(self_, src, non_blocking);
75
76 Py_INCREF(self);
77 return self;
78
79 END_HANDLE_TH_ERRORS
80}
81
82static PyObject* THPStorage_isPinned(PyObject* _self, PyObject* noargs) {
83 HANDLE_TH_ERRORS
84#if defined(USE_CUDA)
85 auto self = (THPStorage*)_self;
86 return PyBool_FromLong(
87 at::globalContext().isPinnedPtr(self->cdata->data<uint8_t>()));
88#else
89 Py_RETURN_FALSE;
90#endif
91 END_HANDLE_TH_ERRORS
92}
93
94static PyObject* THPStorage_elementSize(PyObject* _self, PyObject* noargs) {
95 HANDLE_TH_ERRORS
96 return THPUtils_packInt64(sizeof(uint8_t));
97 END_HANDLE_TH_ERRORS
98}
99
100static PyObject* THPStorage_new(PyObject* _self, PyObject* noargs) {
101 HANDLE_TH_ERRORS
102 auto self = (THPStorage*)_self;
103 c10::Allocator* allocator = self->cdata->allocator();
104 auto new_storage = c10::make_intrusive<at::StorageImpl>(
105 c10::StorageImpl::use_byte_size_t(),
106 0,
107 allocator,
108 /*resizable=*/true);
109
110 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
111 return THPStorage_New(std::move(new_storage));
112 END_HANDLE_TH_ERRORS
113}
114
115static PyObject* THPStorage_resize_(PyObject* _self, PyObject* number_arg) {
116 HANDLE_TH_ERRORS
117 auto self = (THPStorage*)_self;
118 THPUtils_assert(
119 THPUtils_checkLong(number_arg),
120 "resize_ expects an int, "
121 "but got %s",
122 THPUtils_typename(number_arg));
123 int64_t newsize = THPUtils_unpackLong(number_arg);
124 c10::DeviceType device_type = self->cdata->device_type();
125 if (device_type == at::kCPU) {
126 at::native::resize_bytes_cpu(self->cdata, newsize);
127#ifdef USE_CUDA
128 } else if (device_type == at::kCUDA) {
129 ptrdiff_t size_bytes_i = newsize;
130 TORCH_CHECK(
131 !c10::overflows<size_t>(size_bytes_i),
132 "Requested storage size (",
133 size_bytes_i,
134 ") cannot be represented as a size_t");
135 const auto size_bytes = static_cast<size_t>(size_bytes_i);
136 at::native::resize_bytes_cuda(self->cdata, size_bytes);
137#endif
138 } else {
139 TORCH_CHECK(
140 false,
141 "UntypedStorage.resize_: got unexpected device type ",
142 device_type);
143 }
144 Py_INCREF(self);
145 return (PyObject*)self;
146 END_HANDLE_TH_ERRORS
147}
148
149static PyObject* THPStorage_fill_(PyObject* _self, PyObject* number_arg) {
150 HANDLE_TH_ERRORS
151 auto self = (THPStorage*)_self;
152 THPUtils_assert(
153 THPByteUtils_checkReal(number_arg),
154 "fill_ expects int, "
155 "but got %s",
156 THPUtils_typename(number_arg));
157 storage_fill(
158 at::unsafeStorageFromTH(self->cdata, /*retain=*/true),
159 THPByteUtils_unpackReal(number_arg));
160 Py_INCREF(self);
161 return (PyObject*)self;
162 END_HANDLE_TH_ERRORS
163}
164
165static PyObject* THPStorage_fromBuffer(
166 PyObject* _unused,
167 PyObject* args,
168 PyObject* keywds) {
169 HANDLE_TH_ERRORS
170 PyObject* obj = nullptr;
171 const char* byte_order_str = nullptr;
172 Py_ssize_t count = -1, offset = 0;
173 PyObject* dtype_obj = nullptr;
174 c10::ScalarType scalar_type = at::kByte;
175 Py_buffer buffer = {};
176 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
177 constexpr char* kwlist[] = {
178 "buffer", "byte_order", "count", "offset", "dtype", nullptr};
179 constexpr char* argtypes = "O|snnO";
180
181 if (!PyArg_ParseTupleAndKeywords(
182 args,
183 keywds,
184 argtypes,
185 const_cast<char**>(kwlist),
186 &obj,
187 &byte_order_str,
188 &count,
189 &offset,
190 &dtype_obj)) {
191 return nullptr;
192 }
193 TORCH_CHECK(dtype_obj != nullptr, "argument 'dtype' cannot be None");
194 TORCH_CHECK(
195 THPDtype_Check(dtype_obj),
196 "argument 'dtype' must be of type torch.dtype");
197 auto dtype = reinterpret_cast<THPDtype*>(dtype_obj);
198 scalar_type = dtype->scalar_type;
199
200 TORCH_CHECK(
201 (scalar_type == at::kByte) || (scalar_type == at::kChar) ||
202 (byte_order_str != nullptr),
203 "function missing required argument 'byte_order' (pos 2)");
204 size_t element_size = c10::elementSize(scalar_type);
205
206 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
207 bool do_byte_swap;
208 if (scalar_type != at::kByte && scalar_type != at::kChar) {
209 if (strcmp(byte_order_str, "native") == 0) {
210 do_byte_swap = false;
211 } else if (strcmp(byte_order_str, "big") == 0) {
212 do_byte_swap =
213 (torch::utils::THP_LITTLE_ENDIAN ==
214 torch::utils::THP_nativeByteOrder());
215 } else if (strcmp(byte_order_str, "little") == 0) {
216 do_byte_swap =
217 (torch::utils::THP_BIG_ENDIAN == torch::utils::THP_nativeByteOrder());
218 } else {
219 PyErr_Format(
220 PyExc_ValueError,
221 "invalid byte_order '%s' (expected 'big', 'little', or 'native')",
222 byte_order_str);
223 return nullptr;
224 }
225 }
226
227 if (PyObject_GetBuffer(obj, &buffer, PyBUF_SIMPLE) < 0)
228 return nullptr;
229
230 if (offset < 0 || offset > buffer.len) {
231 PyErr_SetString(
232 PyExc_ValueError,
233 fmt::format(
234 "offset must be non-negative and no greater than buffer length ({}) , but got {}",
235 offset,
236 buffer.len));
237 PyBuffer_Release(&buffer);
238 return nullptr;
239 }
240
241 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
242 size_t size_bytes;
243 if (count < 0) {
244 if ((buffer.len - offset) % element_size != 0) {
245 PyErr_SetString(
246 PyExc_ValueError,
247 fmt::format(
248 "buffer size ({}) must be a multiple of element size ({})",
249 buffer.len,
250 element_size));
251 PyBuffer_Release(&buffer);
252 return nullptr;
253 }
254 size_bytes = buffer.len - offset;
255 count = size_bytes / element_size;
256 } else {
257 size_bytes = count * element_size;
258 }
259
260 if (offset + (count * (Py_ssize_t)element_size) > buffer.len) {
261 PyErr_SetString(
262 PyExc_ValueError,
263 fmt::format(
264 "buffer has only {} elements after offset {}, but specified a size of {}",
265 buffer.len - offset,
266 offset,
267 count));
268 PyBuffer_Release(&buffer);
269 return nullptr;
270 }
271
272 uint8_t* src = (uint8_t*)buffer.buf;
273 auto storage = c10::make_intrusive<at::StorageImpl>(
274 c10::StorageImpl::use_byte_size_t(),
275 size_bytes,
276 c10::GetDefaultCPUAllocator(),
277 /*resizable=*/true);
278
279 if (scalar_type == at::kByte || scalar_type == at::kChar) {
280 memcpy(storage->data(), src + offset, count);
281 } else if (scalar_type == at::kBool) {
282 // Because of ASAN checks, that are failing whenever
283 // we are trying to get a value which is not 0 or 1, we have to manually
284 // convert original values to boolean ones.
285 torch::utils::THP_decodeBoolBuffer(
286 storage->data<bool>(), src + offset, do_byte_swap, count);
287 } else if (scalar_type == at::kShort) {
288 torch::utils::THP_decodeInt16Buffer(
289 storage->data<int16_t>(), src + offset, do_byte_swap, count);
290 } else if (scalar_type == at::kInt) {
291 torch::utils::THP_decodeInt32Buffer(
292 storage->data<int32_t>(), src + offset, do_byte_swap, count);
293 } else if (scalar_type == at::kLong) {
294 torch::utils::THP_decodeInt64Buffer(
295 storage->data<int64_t>(), src + offset, do_byte_swap, count);
296 } else if (scalar_type == at::kHalf) {
297 torch::utils::THP_decodeHalfBuffer(
298 storage->data<c10::Half>(), src + offset, do_byte_swap, count);
299 } else if (scalar_type == at::kBFloat16) {
300 torch::utils::THP_decodeBFloat16Buffer(
301 storage->data<c10::BFloat16>(), src + offset, do_byte_swap, count);
302 } else if (scalar_type == at::kFloat) {
303 torch::utils::THP_decodeFloatBuffer(
304 storage->data<float>(), src + offset, do_byte_swap, count);
305 } else if (scalar_type == at::kDouble) {
306 torch::utils::THP_decodeDoubleBuffer(
307 storage->data<double>(), src + offset, do_byte_swap, count);
308 } else if (scalar_type == at::kComplexFloat) {
309 torch::utils::THP_decodeComplexFloatBuffer(
310 storage->data<c10::complex<float>>(),
311 src + offset,
312 do_byte_swap,
313 count);
314 } else if (scalar_type == at::kComplexDouble) {
315 torch::utils::THP_decodeComplexDoubleBuffer(
316 storage->data<c10::complex<double>>(),
317 src + offset,
318 do_byte_swap,
319 count);
320 } else {
321 TORCH_CHECK(false, "Unknown type: ", scalar_type);
322 }
323
324 PyBuffer_Release(&buffer);
325 return (PyObject*)THPStorage_New(storage);
326 END_HANDLE_TH_ERRORS
327}
328
329static PyObject* THPStorage_fromFile(
330 PyObject* _unused,
331 PyObject* args,
332 PyObject* keywds) {
333 HANDLE_TH_ERRORS
334 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
335 const char* filename;
336 Py_ssize_t nbytes = 0;
337 int shared = 0;
338 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
339 constexpr char* kwlist[] = {"filename", "shared", "nbytes", nullptr};
340 if (!PyArg_ParseTupleAndKeywords(
341 args,
342 keywds,
343 "s|in",
344 const_cast<char**>(kwlist),
345 &filename,
346 &shared,
347 &nbytes)) {
348 return nullptr;
349 }
350 if (shared)
351 shared = at::ALLOCATOR_MAPPED_SHARED;
352
353 size_t actual_nbytes = -1;
354 auto storage = c10::make_intrusive<at::StorageImpl>(
355 c10::StorageImpl::use_byte_size_t(),
356 nbytes,
357 at::MapAllocator::makeDataPtr(filename, shared, nbytes, &actual_nbytes),
358 /*allocator=*/nullptr,
359 /*resizable=*/false);
360
361 if (nbytes <= 0) {
362 storage->set_nbytes(actual_nbytes);
363 }
364
365 return (PyObject*)THPStorage_New(std::move(storage));
366 END_HANDLE_TH_ERRORS
367}
368
369PyObject* THPStorage_writeFile(PyObject* _self, PyObject* args) {
370 HANDLE_TH_ERRORS
371 auto self = (THPStorage*)_self;
372 PyObject* file = PyTuple_GetItem(args, 0);
373 bool is_real_file = PyTuple_GetItem(args, 1) == Py_True;
374 bool save_size = PyTuple_GetItem(args, 2) == Py_True;
375 PyObject* element_size_obj = PyTuple_GET_ITEM(args, 3);
376
377 THPUtils_assert(
378 element_size_obj != Py_None, "_write_file: need to specify element size");
379 uint64_t element_size = THPUtils_unpackUInt64(element_size_obj);
380
381 if (!is_real_file) {
382 THPStorage_writeFileRaw<PyObject*>(
383 self->cdata, file, save_size, element_size);
384 Py_RETURN_NONE;
385 }
386
387 int fd = PyObject_AsFileDescriptor(file);
388 THPUtils_assert(
389 fd != -1,
390 "_write_file couldn't retrieve a file descriptor "
391 "from given object");
392 THPStorage_writeFileRaw(self->cdata, fd, save_size, element_size);
393 Py_RETURN_NONE;
394 END_HANDLE_TH_ERRORS
395}
396
397PyObject* THPStorage_newWithFile(PyObject* _unused, PyObject* args) {
398 HANDLE_TH_ERRORS
399 TORCH_CHECK(
400 PyTuple_Size(args) == 2, "_new_with_file takes exactly two arguments");
401 int fd = PyObject_AsFileDescriptor(PyTuple_GetItem(args, 0));
402 THPUtils_assert(
403 fd != -1,
404 "_new_with_file couldn't retrieve a file "
405 "descriptor from given object");
406 PyObject* element_size_obj = PyTuple_GET_ITEM(args, 1);
407 THPUtils_assert(
408 element_size_obj != Py_None,
409 "_new_with_file: need to specify element size");
410 uint64_t element_size = THPUtils_unpackUInt64(element_size_obj);
411
412 auto storage = THPStorage_readFileRaw<int>(fd, {}, element_size);
413 if (!storage.defined())
414 return nullptr;
415 return THPStorage_New(std::move(storage));
416 END_HANDLE_TH_ERRORS
417}
418
419static PyObject* THPStorage_setFromFile(PyObject* _self, PyObject* args) {
420 HANDLE_TH_ERRORS
421 auto self = (THPStorage*)_self;
422 PyObject* file = PyTuple_GET_ITEM(args, 0);
423 PyObject* offset = PyTuple_GET_ITEM(args, 1);
424 bool is_real_file = PyTuple_GET_ITEM(args, 2) == Py_True;
425
426 PyObject* element_size_obj = PyTuple_GET_ITEM(args, 3);
427
428 THPUtils_assert(
429 element_size_obj != Py_None,
430 "_set_from_file: need to specify element size");
431 uint64_t element_size = THPUtils_unpackUInt64(element_size_obj);
432
433 if (!is_real_file) {
434 // offset can be implemented with a call to the Python object's seek()
435 // but it is currently unnecessary to support this.
436 THPUtils_assert(
437 offset == Py_None,
438 "_set_from_file: offset is NYI for filelike objects");
439
440 auto self_storage =
441 c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(self->cdata);
442 auto storage = THPStorage_readFileRaw<PyObject*>(
443 file, std::move(self_storage), element_size);
444 if (!storage.defined()) {
445 return nullptr;
446 }
447 Py_INCREF(self);
448 return (PyObject*)self;
449 }
450
451 // file is backed by a fd
452 const int fd = PyObject_AsFileDescriptor(file);
453 const auto fd_original_pos = LSEEK(fd, 0, SEEK_CUR);
454 if (offset != Py_None) {
455 LSEEK(fd, THPUtils_unpackLong(offset), SEEK_SET);
456 }
457 THPUtils_assert(
458 fd != -1,
459 "_set_from_file couldn't retrieve a file "
460 "descriptor from given object");
461 auto self_storage =
462 c10::intrusive_ptr<c10::StorageImpl>::reclaim_copy(self->cdata);
463 auto storage = THPStorage_readFileRaw<int>(fd, self_storage, element_size);
464 if (!storage.defined())
465 return nullptr;
466 Py_INCREF(self);
467
468 // the file descriptor is returned to original position and
469 // the file handle at python call-site needs updating to the
470 // advanced position
471 const auto fd_current_pos = LSEEK(fd, 0, SEEK_CUR);
472 LSEEK(fd, fd_original_pos, SEEK_SET);
473 const auto seek_return =
474 PyObject_CallMethod(file, "seek", "Li", (long long)fd_current_pos, 0);
475 if (seek_return == nullptr) {
476 return nullptr;
477 }
478 Py_DECREF(seek_return);
479
480 return (PyObject*)self;
481 END_HANDLE_TH_ERRORS
482}
483
484PyObject* THPStorage__setCdata(PyObject* _self, PyObject* new_cdata) {
485 HANDLE_TH_ERRORS
486 auto self = (THPStorage*)_self;
487 THPUtils_assert(
488 THPUtils_checkLong(new_cdata),
489 "given an invalid argument to "
490 "_set_cdata - expected an int or long, but got %s",
491 THPUtils_typename(new_cdata));
492 c10::StorageImpl* ptr = (c10::StorageImpl*)PyLong_AsVoidPtr(new_cdata);
493 if (ptr) {
494 c10::raw::intrusive_ptr::incref(ptr);
495 }
496 if (self->cdata) {
497 c10::raw::intrusive_ptr::decref(self->cdata);
498 }
499 self->cdata = ptr;
500 Py_INCREF(self);
501 return (PyObject*)self;
502 END_HANDLE_TH_ERRORS
503}
504
505// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
506static PyMethodDef THPStorage_methods[] = {
507 {"copy_",
508 castPyCFunctionWithKeywords(THPStorage_copy_),
509 METH_VARARGS | METH_KEYWORDS,
510 nullptr},
511 {"element_size", THPStorage_elementSize, METH_NOARGS, nullptr},
512 {"fill_", THPStorage_fill_, METH_O, nullptr},
513 {"new", THPStorage_new, METH_NOARGS, nullptr},
514 {"resize_", THPStorage_resize_, METH_O, nullptr},
515 {"nbytes", THPStorage_nbytes, METH_NOARGS, nullptr},
516 {"data_ptr", THPStorage_dataPtr, METH_NOARGS, nullptr},
517 {"is_pinned", THPStorage_isPinned, METH_NOARGS, nullptr},
518 {"_write_file", THPStorage_writeFile, METH_VARARGS, nullptr},
519 {"_new_with_file",
520 THPStorage_newWithFile,
521 METH_VARARGS | METH_STATIC,
522 nullptr},
523 {"_set_from_file", THPStorage_setFromFile, METH_VARARGS, nullptr},
524 {"from_buffer",
525 castPyCFunctionWithKeywords(THPStorage_fromBuffer),
526 METH_VARARGS | METH_KEYWORDS | METH_STATIC,
527 nullptr},
528 {"from_file",
529 castPyCFunctionWithKeywords(THPStorage_fromFile),
530 METH_VARARGS | METH_KEYWORDS | METH_STATIC,
531 nullptr},
532 {"_set_cdata", THPStorage__setCdata, METH_O, nullptr},
533 {nullptr}};
534
535PyMethodDef* THPStorage_getMethods() {
536 return THPStorage_methods;
537}
538