1 | #include <torch/csrc/python_headers.h> |
2 | #ifdef _MSC_VER |
3 | #include <c10/util/win32-headers.h> |
4 | #endif |
5 | #include <structmember.h> |
6 | |
7 | #include <ATen/mps/MPSDevice.h> |
8 | #include <c10/core/CPUAllocator.h> |
9 | #include <libshm.h> |
10 | #include <torch/csrc/CudaIPCTypes.h> |
11 | #include <torch/csrc/Device.h> |
12 | #include <torch/csrc/DynamicTypes.h> |
13 | #include <torch/csrc/StorageMethods.h> |
14 | #include <torch/csrc/StorageSharing.h> |
15 | #include <torch/csrc/THP.h> |
16 | #include <torch/csrc/autograd/utils/wrap_outputs.h> |
17 | #include <torch/csrc/copy_utils.h> |
18 | #include <torch/csrc/utils/python_arg_parser.h> |
19 | |
20 | #include <c10/util/intrusive_ptr.h> |
21 | #include <fmt/format.h> |
22 | |
23 | template <> |
24 | void THPPointer<c10::StorageImpl>::free() { |
25 | if (ptr) { |
26 | c10::raw::intrusive_ptr::decref(ptr); |
27 | } |
28 | } |
29 | |
30 | PyObject* THPStorageClass = nullptr; |
31 | |
32 | PyObject* THPStorage_New(c10::intrusive_ptr<c10::StorageImpl> ptr) { |
33 | AT_ASSERT(ptr); |
34 | PyTypeObject* type = (PyTypeObject*)THPStorageClass; |
35 | PyObject* obj = type->tp_alloc(type, 0); |
36 | if (obj) { |
37 | ((THPStorage*)obj)->cdata = ptr.release(); |
38 | } |
39 | return obj; |
40 | } |
41 | |
42 | static void THPStorage_subclass_dealloc(PyObject* self) { |
43 | THPStorage* _self = (THPStorage*)self; |
44 | // Some subclass of StorageBase are GC-tracked objects even |
45 | // though the base class is not. |
46 | auto* type = Py_TYPE(self); |
47 | if (PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) != 0) { |
48 | PyObject_GC_UnTrack(self); |
49 | } |
50 | if (_self->cdata) { |
51 | c10::raw::intrusive_ptr::decref(_self->cdata); |
52 | } |
53 | Py_TYPE(_self)->tp_free(self); |
54 | } |
55 | |
56 | static PyObject* THPStorage_pynew( |
57 | PyTypeObject* type, |
58 | PyObject* args, |
59 | PyObject* kwargs) { |
60 | HANDLE_TH_ERRORS |
61 | TORCH_CHECK( |
62 | type != &THPStorageType, |
63 | "Cannot directly construct StorageBase; subclass it and then construct that" ); |
64 | static torch::PythonArgParser parser({ |
65 | THPStorageStr "(*, int64_t allocator=None, Device device=None)" , |
66 | THPStorageStr |
67 | "(int64_t size, *, int64_t allocator=None, Device device=None)" , |
68 | THPStorageStr |
69 | "(PyObject* sequence, *, int64_t allocator=None, Device device=None)" , |
70 | }); |
71 | torch::ParsedArgs<3> parsed_args; |
72 | auto r = parser.parse(args, kwargs, parsed_args); |
73 | |
74 | int64_t allocator_arg_idx = 0; |
75 | int64_t device_arg_idx = 1; |
76 | |
77 | if (r.idx > 0) { |
78 | allocator_arg_idx = 1; |
79 | device_arg_idx = 2; |
80 | } |
81 | |
82 | c10::optional<int64_t> allocator_opt = r.toInt64Optional(allocator_arg_idx); |
83 | c10::optional<at::Device> device_opt = r.deviceOptional(device_arg_idx); |
84 | |
85 | TORCH_CHECK( |
86 | !allocator_opt.has_value() || !device_opt.has_value(), |
87 | THPStorageStr, |
88 | "(): only one or neither of 'allocator' or 'device' can " , |
89 | "be given, but not both" ); |
90 | |
91 | THPStoragePtr self((THPStorage*)type->tp_alloc(type, 0)); |
92 | THPUtils_assert(self, "failed to allocate a " THPStorageStr " object" ); |
93 | c10::Allocator* allocator = nullptr; |
94 | at::OptionalDeviceGuard device_guard; |
95 | |
96 | if (allocator_opt.has_value()) { |
97 | allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value()); |
98 | } else if (device_opt.has_value()) { |
99 | at::Device device = device_opt.value(); |
100 | if (device.type() == at::kCPU) { |
101 | allocator = c10::GetDefaultCPUAllocator(); |
102 | #ifdef USE_CUDA |
103 | } else if (device.type() == at::kCUDA) { |
104 | at::globalContext().lazyInitCUDA(); |
105 | allocator = c10::cuda::CUDACachingAllocator::get(); |
106 | #endif |
107 | #ifdef USE_MPS |
108 | } else if (device.type() == at::kMPS) { |
109 | allocator = at::mps::GetMPSAllocator(); |
110 | #endif |
111 | } else if (device.type() == at::DeviceType::XPU) { |
112 | allocator = c10::GetAllocator(device.type()); |
113 | } else if (device.type() == at::DeviceType::Meta) { |
114 | allocator = c10::GetAllocator(device.type()); |
115 | } else { |
116 | TORCH_CHECK( |
117 | false, |
118 | THPStorageStr, |
119 | "(): Storage device not recognized: " , |
120 | device.type()); |
121 | } |
122 | device_guard.reset_device(device); |
123 | } else { |
124 | allocator = c10::GetDefaultCPUAllocator(); |
125 | } |
126 | |
127 | // torch.Storage(*, ...) |
128 | if (r.idx == 0) { |
129 | self->cdata = c10::make_intrusive<at::StorageImpl>( |
130 | c10::StorageImpl::use_byte_size_t(), |
131 | 0, |
132 | allocator, |
133 | /*resizable=*/true) |
134 | .release(); |
135 | return (PyObject*)self.release(); |
136 | |
137 | // torch.Storage(size, *, ...) |
138 | } else if (r.idx == 1) { |
139 | int64_t size = r.toInt64(0); |
140 | self->cdata = c10::make_intrusive<at::StorageImpl>( |
141 | c10::StorageImpl::use_byte_size_t(), |
142 | size, |
143 | allocator, |
144 | /*resizable=*/true) |
145 | .release(); |
146 | return (PyObject*)self.release(); |
147 | |
148 | // torch.Storage(sequence, *, ...) |
149 | } else if (r.idx == 2) { |
150 | PyObject* sequence = r.pyobject(0); |
151 | Py_ssize_t length = PySequence_Length(sequence); |
152 | TORCH_CHECK( |
153 | PySequence_Check(sequence), |
154 | THPStorageStr, |
155 | "(): Expected a sequence type, but got " , |
156 | THPUtils_typename(sequence)); |
157 | TORCH_CHECK( |
158 | length >= 0, |
159 | THPStorageStr, |
160 | "(): Could not obtain the length of sequence of type " , |
161 | THPUtils_typename(sequence)); |
162 | self->cdata = c10::make_intrusive<at::StorageImpl>( |
163 | c10::StorageImpl::use_byte_size_t(), |
164 | length, |
165 | allocator, |
166 | /*resizable=*/true) |
167 | .release(); |
168 | THPObjectPtr item; |
169 | try { |
170 | for (Py_ssize_t i = 0; i < length; i++) { |
171 | item = PySequence_GetItem(sequence, i); |
172 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
173 | uint8_t value = THPByteUtils_unpackReal(item.get()); |
174 | if (allocator == c10::GetDefaultCPUAllocator()) { |
175 | self->cdata->unsafe_data<uint8_t>()[i] = value; |
176 | } else { |
177 | // TODO: this might be slow - consider batched updates? |
178 | storage_set( |
179 | at::unsafeStorageFromTH(self->cdata, /*retain=*/true), i, value); |
180 | } |
181 | } |
182 | } catch (const std::exception& e) { |
183 | THPUtils_setError( |
184 | THPStorageStr |
185 | "(): tried to construct a storage from a sequence (%s), " |
186 | "but one of the items was of type %s instead of int" , |
187 | THPUtils_typename(sequence), |
188 | THPUtils_typename(item.get())); |
189 | return nullptr; |
190 | } |
191 | return (PyObject*)self.release(); |
192 | } |
193 | Py_RETURN_NONE; |
194 | END_HANDLE_TH_ERRORS |
195 | } |
196 | |
197 | static Py_ssize_t THPStorage_length(THPStorage* self) { |
198 | HANDLE_TH_ERRORS |
199 | return self->cdata->nbytes() / sizeof(uint8_t); |
200 | END_HANDLE_TH_ERRORS_RET(-1) |
201 | } |
202 | |
203 | static PyObject* THPStorage_get(THPStorage* self, PyObject* index) { |
204 | HANDLE_TH_ERRORS |
205 | /* Integer index */ |
206 | if (THPUtils_checkLong(index)) { |
207 | int64_t nindex = THPUtils_unpackLong(index); |
208 | if (nindex < 0) |
209 | nindex += (self->cdata->nbytes() / sizeof(uint8_t)); |
210 | if (nindex < 0 || |
211 | nindex >= |
212 | static_cast<int64_t>(self->cdata->nbytes() / sizeof(uint8_t))) { |
213 | PyErr_SetString( |
214 | PyExc_IndexError, |
215 | fmt::format( |
216 | "index {} out of range for storage of size {}" , |
217 | nindex, |
218 | self->cdata->nbytes() / sizeof(uint8_t))); |
219 | return nullptr; |
220 | } |
221 | uint8_t value = storage_get( |
222 | at::unsafeStorageFromTH(self->cdata, /*retain=*/true), nindex); |
223 | return THPByteUtils_newReal(value); |
224 | /* Slice index */ |
225 | } else if (PySlice_Check(index)) { |
226 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
227 | Py_ssize_t start, stop, slicelength, step; |
228 | int64_t len = self->cdata->nbytes() / sizeof(uint8_t); |
229 | if (PySlice_GetIndicesEx(index, len, &start, &stop, &step, &slicelength) != |
230 | 0) |
231 | return nullptr; |
232 | if (step != 1) { |
233 | THPUtils_setError( |
234 | "Trying to slice with a step of %lld, but only a step of " |
235 | "1 is supported" , |
236 | (long long)step); |
237 | return nullptr; |
238 | } |
239 | |
240 | uint8_t* data = self->cdata->data<uint8_t>(); |
241 | |
242 | at::StorageImpl* old_storage = self->cdata; |
243 | c10::raw::intrusive_ptr::incref(old_storage); |
244 | auto new_storage = c10::make_intrusive<at::StorageImpl>( |
245 | c10::StorageImpl::use_byte_size_t(), |
246 | #ifdef THQUANTIZED |
247 | slicelength * sizeof(quantized_t), |
248 | #else |
249 | slicelength * sizeof(uint8_t), |
250 | #endif |
251 | at::DataPtr( |
252 | static_cast<void*>(data + start), |
253 | old_storage, |
254 | [](void* s) { |
255 | c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s)); |
256 | }, |
257 | old_storage->device()), |
258 | old_storage->allocator(), |
259 | /* resizable */ false); |
260 | |
261 | PyObject* _ret = THPStorage_New(std::move(new_storage)); |
262 | return _ret; |
263 | } |
264 | PyErr_Format( |
265 | PyExc_TypeError, |
266 | "can't index a " THPStorageStr " with %s" , |
267 | THPUtils_typename(index)); |
268 | return nullptr; |
269 | END_HANDLE_TH_ERRORS |
270 | } |
271 | |
272 | static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) { |
273 | HANDLE_TH_ERRORS |
274 | if (!THPByteUtils_checkReal(value)) { |
275 | THPUtils_setError( |
276 | "can only set storage content with a int types, but got " |
277 | "%s instead" , |
278 | THPUtils_typename(value)); |
279 | return -1; |
280 | } |
281 | |
282 | uint8_t rvalue = THPByteUtils_unpackReal(value); |
283 | if (THPUtils_checkLong(index)) { |
284 | int64_t nindex = THPUtils_unpackLong(index); |
285 | storage_set( |
286 | at::unsafeStorageFromTH(self->cdata, /*retain=*/true), nindex, rvalue); |
287 | return 0; |
288 | } else if (PySlice_Check(index)) { |
289 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
290 | Py_ssize_t start, stop, slicelength, step; |
291 | int64_t len = self->cdata->nbytes() / sizeof(uint8_t); |
292 | if (PySlice_GetIndicesEx(index, len, &start, &stop, &step, &slicelength) != |
293 | 0) |
294 | return -1; |
295 | if (step != 1) { |
296 | THPUtils_setError( |
297 | "Trying to slice with a step of %lld, but only a step of " |
298 | "1 is supported" , |
299 | (long long)step); |
300 | return 0; |
301 | } |
302 | // TODO: check the bounds only once |
303 | // TODO: fill? |
304 | for (; start < stop; start++) |
305 | storage_set( |
306 | at::unsafeStorageFromTH(self->cdata, /*retain=*/true), start, rvalue); |
307 | return 0; |
308 | } |
309 | THPUtils_setError( |
310 | "can't index a " THPStorageStr " with %s" , THPUtils_typename(index)); |
311 | return -1; |
312 | END_HANDLE_TH_ERRORS_RET(-1) |
313 | } |
314 | |
315 | static PyMappingMethods THPStorage_mappingmethods = { |
316 | (lenfunc)THPStorage_length, |
317 | (binaryfunc)THPStorage_get, |
318 | (objobjargproc)THPStorage_set}; |
319 | |
// Instance layout of the _StorageMeta metaclass: a plain heap type object
// with no extra fields. Its size is used as tp_basicsize of
// THPStorageMetaType below.
struct THPStorageMeta {
  PyHeapTypeObject base;
};

// Defined further down; installed as THPStorageMetaType's tp_init.
int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs);
325 | |
// Metaclass of StorageBase, exposed to Python as torch._C._StorageMeta.
//
// Its only customized slot is tp_init (THPStorageMetaType_init). That slot
// runs for each class created with this metaclass and installs
// THPStorage_subclass_dealloc as the new class's tp_dealloc.
// THPStorage_init sets tp_base to &PyType_Type before calling PyType_Ready.
// DEFERRED_ADDRESS presumably stands in for addresses that cannot appear in
// this static initializer; confirm against its definition.
// Slots are positional: the order below must match PyTypeObject.
PyTypeObject THPStorageMetaType = {
    PyVarObject_HEAD_INIT(
        DEFERRED_ADDRESS(&PyType_Type),
        0) "torch._C._StorageMeta", /* tp_name */
    sizeof(THPStorageMeta), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    DEFERRED_ADDRESS(&PyType_Type), /* tp_base (set in THPStorage_init) */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    THPStorageMetaType_init, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr, /* tp_new */
};
367 | |
// TODO: implement equality
// Python type torch._C.StorageBase. Instances are THPStorage objects whose
// `cdata` points to a c10::StorageImpl.
// - Its metatype is THPStorageMetaType, so Python subclasses get
//   THPStorage_subclass_dealloc as their tp_dealloc.
// - tp_methods, tp_members and tp_getset are assigned in THPStorage_init
//   before PyType_Ready.
// - THPStorage_pynew rejects direct construction of StorageBase itself.
// Slots are positional: the order below must match PyTypeObject.
PyTypeObject THPStorageType = {
    PyVarObject_HEAD_INIT(
        &THPStorageMetaType,
        0) "torch._C.StorageBase", /* tp_name */
    sizeof(THPStorage), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    &THPStorage_mappingmethods, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods: assigned in THPStorage_init */
    nullptr, /* tp_members: assigned in THPStorage_init */
    nullptr, /* tp_getset: assigned in THPStorage_init */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPStorage_pynew, /* tp_new */
};
412 | |
413 | int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) { |
414 | if (PyType_Type.tp_init(cls, args, kwargs) < 0) { |
415 | return -1; |
416 | } |
417 | ((PyTypeObject*)cls)->tp_dealloc = (destructor)THPStorage_subclass_dealloc; |
418 | return 0; |
419 | } |
420 | |
421 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
422 | static struct PyMemberDef THPStorage_members[] = { |
423 | {(char*)"_cdata" , |
424 | T_ULONGLONG, |
425 | offsetof(THPStorage, cdata), |
426 | READONLY, |
427 | nullptr}, |
428 | {nullptr}}; |
429 | |
430 | static PyObject* THPStorage_device(THPStorage* self, void* unused) { |
431 | HANDLE_TH_ERRORS |
432 | return THPDevice_New(self->cdata->device()); |
433 | END_HANDLE_TH_ERRORS |
434 | } |
435 | |
436 | typedef PyObject* (*getter)(PyObject*, void*); |
437 | |
438 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
439 | static struct PyGetSetDef THPStorage_properties[] = { |
440 | {"device" , (getter)THPStorage_device, nullptr, nullptr, nullptr}, |
441 | {nullptr}}; |
442 | |
443 | bool THPStorage_init(PyObject* module) { |
444 | static std::vector<PyMethodDef> methods; |
445 | THPUtils_addPyMethodDefs(methods, THPStorage_getMethods()); |
446 | THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods()); |
447 | |
448 | THPStorageMetaType.tp_base = &PyType_Type; |
449 | if (PyType_Ready(&THPStorageMetaType) < 0) |
450 | return false; |
451 | Py_INCREF(&THPStorageMetaType); |
452 | PyModule_AddObject(module, "_StorageMeta" , (PyObject*)&THPStorageMetaType); |
453 | |
454 | THPStorageType.tp_methods = methods.data(); |
455 | THPStorageType.tp_members = THPStorage_members; |
456 | THPStorageType.tp_getset = THPStorage_properties; |
457 | if (PyType_Ready(&THPStorageType) < 0) |
458 | return false; |
459 | Py_INCREF(&THPStorageType); |
460 | PyModule_AddObject(module, "StorageBase" , (PyObject*)&THPStorageType); |
461 | return true; |
462 | } |
463 | |
464 | void THPStorage_postInit(PyObject* module) { |
465 | THPStorageClass = PyObject_GetAttrString(module, "UntypedStorage" ); |
466 | if (!THPStorageClass) |
467 | throw python_error(); |
468 | } |
469 | |