1#include <torch/csrc/python_headers.h>
2#ifdef _MSC_VER
3#include <c10/util/win32-headers.h>
4#endif
5#include <structmember.h>
6
7#include <ATen/mps/MPSDevice.h>
8#include <c10/core/CPUAllocator.h>
9#include <libshm.h>
10#include <torch/csrc/CudaIPCTypes.h>
11#include <torch/csrc/Device.h>
12#include <torch/csrc/DynamicTypes.h>
13#include <torch/csrc/StorageMethods.h>
14#include <torch/csrc/StorageSharing.h>
15#include <torch/csrc/THP.h>
16#include <torch/csrc/autograd/utils/wrap_outputs.h>
17#include <torch/csrc/copy_utils.h>
18#include <torch/csrc/utils/python_arg_parser.h>
19
20#include <c10/util/intrusive_ptr.h>
21#include <fmt/format.h>
22
23template <>
24void THPPointer<c10::StorageImpl>::free() {
25 if (ptr) {
26 c10::raw::intrusive_ptr::decref(ptr);
27 }
28}
29
// Python class object used by THPStorage_New as the type of newly wrapped
// storages. It is nullptr until THPStorage_postInit stores the module's
// "UntypedStorage" attribute here.
PyObject* THPStorageClass = nullptr;
31
32PyObject* THPStorage_New(c10::intrusive_ptr<c10::StorageImpl> ptr) {
33 AT_ASSERT(ptr);
34 PyTypeObject* type = (PyTypeObject*)THPStorageClass;
35 PyObject* obj = type->tp_alloc(type, 0);
36 if (obj) {
37 ((THPStorage*)obj)->cdata = ptr.release();
38 }
39 return obj;
40}
41
42static void THPStorage_subclass_dealloc(PyObject* self) {
43 THPStorage* _self = (THPStorage*)self;
44 // Some subclass of StorageBase are GC-tracked objects even
45 // though the base class is not.
46 auto* type = Py_TYPE(self);
47 if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) {
48 PyObject_GC_UnTrack(self);
49 }
50 if (_self->cdata) {
51 c10::raw::intrusive_ptr::decref(_self->cdata);
52 }
53 Py_TYPE(_self)->tp_free(self);
54}
55
56static PyObject* THPStorage_pynew(
57 PyTypeObject* type,
58 PyObject* args,
59 PyObject* kwargs) {
60 HANDLE_TH_ERRORS
61 TORCH_CHECK(
62 type != &THPStorageType,
63 "Cannot directly construct StorageBase; subclass it and then construct that");
64 static torch::PythonArgParser parser({
65 THPStorageStr "(*, int64_t allocator=None, Device device=None)",
66 THPStorageStr
67 "(int64_t size, *, int64_t allocator=None, Device device=None)",
68 THPStorageStr
69 "(PyObject* sequence, *, int64_t allocator=None, Device device=None)",
70 });
71 torch::ParsedArgs<3> parsed_args;
72 auto r = parser.parse(args, kwargs, parsed_args);
73
74 int64_t allocator_arg_idx = 0;
75 int64_t device_arg_idx = 1;
76
77 if (r.idx > 0) {
78 allocator_arg_idx = 1;
79 device_arg_idx = 2;
80 }
81
82 c10::optional<int64_t> allocator_opt = r.toInt64Optional(allocator_arg_idx);
83 c10::optional<at::Device> device_opt = r.deviceOptional(device_arg_idx);
84
85 TORCH_CHECK(
86 !allocator_opt.has_value() || !device_opt.has_value(),
87 THPStorageStr,
88 "(): only one or neither of 'allocator' or 'device' can ",
89 "be given, but not both");
90
91 THPStoragePtr self((THPStorage*)type->tp_alloc(type, 0));
92 THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
93 c10::Allocator* allocator = nullptr;
94 at::OptionalDeviceGuard device_guard;
95
96 if (allocator_opt.has_value()) {
97 allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
98 } else if (device_opt.has_value()) {
99 at::Device device = device_opt.value();
100 if (device.type() == at::kCPU) {
101 allocator = c10::GetDefaultCPUAllocator();
102#ifdef USE_CUDA
103 } else if (device.type() == at::kCUDA) {
104 at::globalContext().lazyInitCUDA();
105 allocator = c10::cuda::CUDACachingAllocator::get();
106#endif
107#ifdef USE_MPS
108 } else if (device.type() == at::kMPS) {
109 allocator = at::mps::GetMPSAllocator();
110#endif
111 } else if (device.type() == at::DeviceType::XPU) {
112 allocator = c10::GetAllocator(device.type());
113 } else if (device.type() == at::DeviceType::Meta) {
114 allocator = c10::GetAllocator(device.type());
115 } else {
116 TORCH_CHECK(
117 false,
118 THPStorageStr,
119 "(): Storage device not recognized: ",
120 device.type());
121 }
122 device_guard.reset_device(device);
123 } else {
124 allocator = c10::GetDefaultCPUAllocator();
125 }
126
127 // torch.Storage(*, ...)
128 if (r.idx == 0) {
129 self->cdata = c10::make_intrusive<at::StorageImpl>(
130 c10::StorageImpl::use_byte_size_t(),
131 0,
132 allocator,
133 /*resizable=*/true)
134 .release();
135 return (PyObject*)self.release();
136
137 // torch.Storage(size, *, ...)
138 } else if (r.idx == 1) {
139 int64_t size = r.toInt64(0);
140 self->cdata = c10::make_intrusive<at::StorageImpl>(
141 c10::StorageImpl::use_byte_size_t(),
142 size,
143 allocator,
144 /*resizable=*/true)
145 .release();
146 return (PyObject*)self.release();
147
148 // torch.Storage(sequence, *, ...)
149 } else if (r.idx == 2) {
150 PyObject* sequence = r.pyobject(0);
151 Py_ssize_t length = PySequence_Length(sequence);
152 TORCH_CHECK(
153 PySequence_Check(sequence),
154 THPStorageStr,
155 "(): Expected a sequence type, but got ",
156 THPUtils_typename(sequence));
157 TORCH_CHECK(
158 length >= 0,
159 THPStorageStr,
160 "(): Could not obtain the length of sequence of type ",
161 THPUtils_typename(sequence));
162 self->cdata = c10::make_intrusive<at::StorageImpl>(
163 c10::StorageImpl::use_byte_size_t(),
164 length,
165 allocator,
166 /*resizable=*/true)
167 .release();
168 THPObjectPtr item;
169 try {
170 for (Py_ssize_t i = 0; i < length; i++) {
171 item = PySequence_GetItem(sequence, i);
172 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
173 uint8_t value = THPByteUtils_unpackReal(item.get());
174 if (allocator == c10::GetDefaultCPUAllocator()) {
175 self->cdata->unsafe_data<uint8_t>()[i] = value;
176 } else {
177 // TODO: this might be slow - consider batched updates?
178 storage_set(
179 at::unsafeStorageFromTH(self->cdata, /*retain=*/true), i, value);
180 }
181 }
182 } catch (const std::exception& e) {
183 THPUtils_setError(
184 THPStorageStr
185 "(): tried to construct a storage from a sequence (%s), "
186 "but one of the items was of type %s instead of int",
187 THPUtils_typename(sequence),
188 THPUtils_typename(item.get()));
189 return nullptr;
190 }
191 return (PyObject*)self.release();
192 }
193 Py_RETURN_NONE;
194 END_HANDLE_TH_ERRORS
195}
196
197static Py_ssize_t THPStorage_length(THPStorage* self) {
198 HANDLE_TH_ERRORS
199 return self->cdata->nbytes() / sizeof(uint8_t);
200 END_HANDLE_TH_ERRORS_RET(-1)
201}
202
203static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
204 HANDLE_TH_ERRORS
205 /* Integer index */
206 if (THPUtils_checkLong(index)) {
207 int64_t nindex = THPUtils_unpackLong(index);
208 if (nindex < 0)
209 nindex += (self->cdata->nbytes() / sizeof(uint8_t));
210 if (nindex < 0 ||
211 nindex >=
212 static_cast<int64_t>(self->cdata->nbytes() / sizeof(uint8_t))) {
213 PyErr_SetString(
214 PyExc_IndexError,
215 fmt::format(
216 "index {} out of range for storage of size {}",
217 nindex,
218 self->cdata->nbytes() / sizeof(uint8_t)));
219 return nullptr;
220 }
221 uint8_t value = storage_get(
222 at::unsafeStorageFromTH(self->cdata, /*retain=*/true), nindex);
223 return THPByteUtils_newReal(value);
224 /* Slice index */
225 } else if (PySlice_Check(index)) {
226 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
227 Py_ssize_t start, stop, slicelength, step;
228 int64_t len = self->cdata->nbytes() / sizeof(uint8_t);
229 if (PySlice_GetIndicesEx(index, len, &start, &stop, &step, &slicelength) !=
230 0)
231 return nullptr;
232 if (step != 1) {
233 THPUtils_setError(
234 "Trying to slice with a step of %lld, but only a step of "
235 "1 is supported",
236 (long long)step);
237 return nullptr;
238 }
239
240 uint8_t* data = self->cdata->data<uint8_t>();
241
242 at::StorageImpl* old_storage = self->cdata;
243 c10::raw::intrusive_ptr::incref(old_storage);
244 auto new_storage = c10::make_intrusive<at::StorageImpl>(
245 c10::StorageImpl::use_byte_size_t(),
246#ifdef THQUANTIZED
247 slicelength * sizeof(quantized_t),
248#else
249 slicelength * sizeof(uint8_t),
250#endif
251 at::DataPtr(
252 static_cast<void*>(data + start),
253 old_storage,
254 [](void* s) {
255 c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s));
256 },
257 old_storage->device()),
258 old_storage->allocator(),
259 /* resizable */ false);
260
261 PyObject* _ret = THPStorage_New(std::move(new_storage));
262 return _ret;
263 }
264 PyErr_Format(
265 PyExc_TypeError,
266 "can't index a " THPStorageStr " with %s",
267 THPUtils_typename(index));
268 return nullptr;
269 END_HANDLE_TH_ERRORS
270}
271
272static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
273 HANDLE_TH_ERRORS
274 if (!THPByteUtils_checkReal(value)) {
275 THPUtils_setError(
276 "can only set storage content with a int types, but got "
277 "%s instead",
278 THPUtils_typename(value));
279 return -1;
280 }
281
282 uint8_t rvalue = THPByteUtils_unpackReal(value);
283 if (THPUtils_checkLong(index)) {
284 int64_t nindex = THPUtils_unpackLong(index);
285 storage_set(
286 at::unsafeStorageFromTH(self->cdata, /*retain=*/true), nindex, rvalue);
287 return 0;
288 } else if (PySlice_Check(index)) {
289 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
290 Py_ssize_t start, stop, slicelength, step;
291 int64_t len = self->cdata->nbytes() / sizeof(uint8_t);
292 if (PySlice_GetIndicesEx(index, len, &start, &stop, &step, &slicelength) !=
293 0)
294 return -1;
295 if (step != 1) {
296 THPUtils_setError(
297 "Trying to slice with a step of %lld, but only a step of "
298 "1 is supported",
299 (long long)step);
300 return 0;
301 }
302 // TODO: check the bounds only once
303 // TODO: fill?
304 for (; start < stop; start++)
305 storage_set(
306 at::unsafeStorageFromTH(self->cdata, /*retain=*/true), start, rvalue);
307 return 0;
308 }
309 THPUtils_setError(
310 "can't index a " THPStorageStr " with %s", THPUtils_typename(index));
311 return -1;
312 END_HANDLE_TH_ERRORS_RET(-1)
313}
314
// Mapping protocol table referenced from THPStorageType's tp_as_mapping:
// len(storage), storage[index], and storage[index] = value.
static PyMappingMethods THPStorage_mappingmethods = {
    (lenfunc)THPStorage_length, // length
    (binaryfunc)THPStorage_get, // subscript read
    (objobjargproc)THPStorage_set}; // subscript assignment
319
// Object layout for instances of the storage metaclass (THPStorageMetaType);
// it adds no fields beyond the heap type object it embeds.
struct THPStorageMeta {
  PyHeapTypeObject base;
};

// Forward declaration: referenced by THPStorageMetaType's tp_init slot and
// defined after THPStorageType.
int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs);
325
// Metaclass used by THPStorageType ("torch._C._StorageMeta").
//
// Its only customization is tp_init (THPStorageMetaType_init), which installs
// THPStorage_subclass_dealloc as tp_dealloc on each class it initializes.
// tp_base is assigned to &PyType_Type at runtime in THPStorage_init, before
// PyType_Ready is called on this type.
PyTypeObject THPStorageMetaType = {
    PyVarObject_HEAD_INIT(
        DEFERRED_ADDRESS(&PyType_Type),
        0) "torch._C._StorageMeta", /* tp_name */
    sizeof(THPStorageMeta), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    DEFERRED_ADDRESS(&PyType_Type), /* tp_base (set in THPStorage_init) */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    THPStorageMetaType_init, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr, /* tp_new */
};
367
// Python type object for "torch._C.StorageBase".
//
// Its metatype is THPStorageMetaType. It supports the mapping protocol via
// THPStorage_mappingmethods and is constructed through THPStorage_pynew,
// which refuses to instantiate this base type directly. tp_methods,
// tp_members and tp_getset are left null here and assigned in
// THPStorage_init before PyType_Ready is called.
//
// TODO: implement equality
PyTypeObject THPStorageType = {
    PyVarObject_HEAD_INIT(
        &THPStorageMetaType,
        0) "torch._C.StorageBase", /* tp_name */
    sizeof(THPStorage), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    &THPStorage_mappingmethods, /* tp_as_mapping */
    nullptr, /* tp_hash  */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr,
    /* will be assigned in init */ /* tp_methods */
    nullptr,
    /* will be assigned in init */ /* tp_members */
    nullptr, /* tp_getset (assigned in THPStorage_init) */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPStorage_pynew, /* tp_new */
};
412
413int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
414 if (PyType_Type.tp_init(cls, args, kwargs) < 0) {
415 return -1;
416 }
417 ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPStorage_subclass_dealloc;
418 return 0;
419}
420
// Member table assigned to THPStorageType.tp_members in THPStorage_init.
// Exposes the `cdata` field of THPStorage as the read-only attribute
// `_cdata`, read as an unsigned long long. The table ends with a null
// sentinel entry.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMemberDef THPStorage_members[] = {
    {(char*)"_cdata",
     T_ULONGLONG,
     offsetof(THPStorage, cdata),
     READONLY,
     nullptr},
    {nullptr}};
429
430static PyObject* THPStorage_device(THPStorage* self, void* unused) {
431 HANDLE_TH_ERRORS
432 return THPDevice_New(self->cdata->device());
433 END_HANDLE_TH_ERRORS
434}
435
// Function-pointer type used to cast property getters below, which are
// declared with a THPStorage* first parameter rather than PyObject*.
typedef PyObject* (*getter)(PyObject*, void*);

// Property table assigned to THPStorageType.tp_getset in THPStorage_init.
// `device` is read-only (no setter); the table ends with a null sentinel.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPStorage_properties[] = {
    {"device", (getter)THPStorage_device, nullptr, nullptr, nullptr},
    {nullptr}};
442
443bool THPStorage_init(PyObject* module) {
444 static std::vector<PyMethodDef> methods;
445 THPUtils_addPyMethodDefs(methods, THPStorage_getMethods());
446 THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods());
447
448 THPStorageMetaType.tp_base = &PyType_Type;
449 if (PyType_Ready(&THPStorageMetaType) < 0)
450 return false;
451 Py_INCREF(&THPStorageMetaType);
452 PyModule_AddObject(module, "_StorageMeta", (PyObject*)&THPStorageMetaType);
453
454 THPStorageType.tp_methods = methods.data();
455 THPStorageType.tp_members = THPStorage_members;
456 THPStorageType.tp_getset = THPStorage_properties;
457 if (PyType_Ready(&THPStorageType) < 0)
458 return false;
459 Py_INCREF(&THPStorageType);
460 PyModule_AddObject(module, "StorageBase", (PyObject*)&THPStorageType);
461 return true;
462}
463
464void THPStorage_postInit(PyObject* module) {
465 THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage");
466 if (!THPStorageClass)
467 throw python_error();
468}
469