#include <torch/csrc/python_headers.h>
#ifdef _MSC_VER
#include <c10/util/win32-headers.h>
#endif
#include <structmember.h>

#include <c10/core/CPUAllocator.h>
#include <libshm.h>
#include <torch/csrc/CudaIPCTypes.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/copy_utils.h>

#include <c10/util/intrusive_ptr.h>
#include <fmt/format.h>

#include <torch/csrc/Storage.h>
#include <torch/csrc/StorageSharing.h>

#ifdef USE_CUDA
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#endif

#include <ATen/MapAllocator.h>
#include <torch/csrc/utils/python_numbers.h>
#include <atomic>
#include <string>
static PyObject* THPStorage_sharedDecref(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = (THPStorage*)_self;
  c10::DeviceType device_type = self->cdata->device_type();
  if (device_type == at::kCPU) {
    c10::StorageImpl* storage = self->cdata;
    THManagedMapAllocator* ctx =
        THManagedMapAllocator::fromDataPtr(storage->data_ptr());
    if (ctx) {
      ctx->decref();
    }
  }
  Py_INCREF(self);
  return (PyObject*)self;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPStorage_sharedIncref(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = (THPStorage*)_self;
  c10::DeviceType device_type = self->cdata->device_type();
  if (device_type == at::kCPU) {
    c10::StorageImpl* storage = self->cdata;
    THManagedMapAllocator* ctx =
        THManagedMapAllocator::fromDataPtr(storage->data_ptr());
    if (ctx) {
      ctx->incref();
    }
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
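
// _new_using_filename_cpu: allocate a brand-new CPU storage of `size` bytes
// directly in manager-backed shared memory, under a freshly generated
// process-wide handle.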
static PyObject* THPStorage_pyNewFilenameStorage(
    PyObject* _unused,
    PyObject* args) {
  HANDLE_TH_ERRORS
  long long size = 0;
  if (!PyArg_ParseTuple(args, "L", &size)) {
    return nullptr;
  }

  int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE;
  std::string handle = at::NewProcessWideShmHandle();
  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      size,
      THManagedMapAllocator::makeDataPtr("", handle.c_str(), flags, size),
      /*allocator=*/nullptr,
      /*resizable=*/false));
  END_HANDLE_TH_ERRORS
}
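
// _share_filename_cpu_: move a CPU storage into manager-backed shared memory
// (if it is not there already) and return a (manager_handle, storage_handle,
// size) tuple that a consumer process can use to attach to the same segment.
// A hedged sketch of the Python-side round trip, using the binding names from
// the method table at the bottom of this file (StorageClass stands in for
// whatever storage type these methods are registered on):
//
//   manager, handle, size = storage._share_filename_cpu_()           # producer
//   s = StorageClass._new_shared_filename_cpu(manager, handle, size) # consumer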
static PyObject* THPStorage_shareFilename(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      reinterpret_cast<THPStorage*>(_self)->cdata->device_type() == at::kCPU,
      "_share_filename_: only available on CPU");
  auto self = (THPStorage*)_self;
  c10::StorageImpl* storage = self->cdata;
  THManagedMapAllocator* ctx =
      THManagedMapAllocator::fromDataPtr(storage->data_ptr());
  // If the storage is already in shared memory, we can just return a handle;
  // otherwise move it into shared memory first.
  if (!ctx) {
    // TODO: retry on collision
    // TODO: free GIL - but remember to reacquire it when an exception is
    // thrown
    int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE;
    std::string handle = at::NewProcessWideShmHandle();
    at::Storage new_storage(c10::make_intrusive<at::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(),
        storage->nbytes(),
        THManagedMapAllocator::makeDataPtr(
            "", handle.c_str(), flags, storage->nbytes()),
        /*allocator=*/nullptr,
        /*resizable=*/false));

    at::Storage _self_aten = torch::createStorage(_self);
    {
      // Copying into shared memory can be slow, so release the GIL
      pybind11::gil_scoped_release no_gil;
      storage_copy(new_storage, _self_aten);
    }

    std::swap(*storage, *new_storage.unsafeGetStorageImpl());
    ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr());
    AT_ASSERT(ctx);
  }

  THPObjectPtr manager_handle(PyBytes_FromString(ctx->manager_handle()));
  if (!manager_handle)
    return nullptr;
  THPObjectPtr storage_handle(PyBytes_FromString(ctx->filename()));
  if (!storage_handle)
    return nullptr;
  THPObjectPtr size(THPUtils_packUInt64(storage->nbytes() / sizeof(uint8_t)));
  if (!size)
    return nullptr;

  THPObjectPtr tuple(PyTuple_New(3));
  if (!tuple)
    return nullptr;
  PyTuple_SET_ITEM(tuple.get(), 0, manager_handle.release());
  PyTuple_SET_ITEM(tuple.get(), 1, storage_handle.release());
  PyTuple_SET_ITEM(tuple.get(), 2, size.release());
  return tuple.release();
  END_HANDLE_TH_ERRORS
}
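
// _new_shared_filename_cpu: attach to an existing manager-backed segment
// identified by (manager_handle, object_handle). ALLOCATOR_MAPPED_NOCREATE
// makes attaching fail if the segment does not already exist, rather than
// silently creating a new one.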
static PyObject* THPStorage_newSharedFilename(
    PyObject* _unused,
    PyObject* args) {
  HANDLE_TH_ERRORS
  THPUtils_assert(PyTuple_GET_SIZE(args) == 3, "tuple of 3 items expected");
  PyObject* _manager_handle = PyTuple_GET_ITEM(args, 0);
  PyObject* _object_handle = PyTuple_GET_ITEM(args, 1);
  PyObject* _size = PyTuple_GET_ITEM(args, 2);
  if (!PyBytes_Check(_manager_handle) || !PyBytes_Check(_object_handle) ||
      !THPUtils_checkLong(_size)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "_new_shared in file system mode",
        1,
        "(bytes manager_handle, bytes object_handle, int size)");
    return nullptr;
  }
  const char* manager_handle = PyBytes_AS_STRING(_manager_handle);
  const char* object_handle = PyBytes_AS_STRING(_object_handle);
  int64_t size = THPUtils_unpackLong(_size);
  int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE;
  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      size,
      THManagedMapAllocator::makeDataPtr(
          manager_handle, object_handle, flags, size),
      /*allocator=*/nullptr,
      /*resizable=*/false));
  END_HANDLE_TH_ERRORS
}
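
// Allocate a storage in fd-backed shared memory: the mapping keeps the file
// descriptor open (KEEPFD) and unlinks the backing name right away (UNLINK),
// so the segment lives exactly as long as some process holds a descriptor
// to it.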
static c10::intrusive_ptr<c10::StorageImpl> THPStorage_newFdStorage(
    ptrdiff_t size) {
  int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE |
      at::ALLOCATOR_MAPPED_KEEPFD | at::ALLOCATOR_MAPPED_UNLINK;
  std::string handle = at::NewProcessWideShmHandle();
  auto sptr = at::MapAllocator::makeDataPtr(
      handle, flags, size * sizeof(uint8_t), nullptr);
  return c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      size,
      std::move(sptr),
      /*allocator=*/nullptr,
      /*resizable=*/false);
}

static PyObject* THPStorage_pyNewFdStorage(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  long long size = 0;
  if (!PyArg_ParseTuple(args, "L", &size)) {
    return nullptr;
  }
  return THPStorage_New(THPStorage_newFdStorage(size));
  END_HANDLE_TH_ERRORS
}
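
// _share_fd_cpu_: move a CPU storage into fd-backed shared memory (if it is
// not there already) and return an (fd, size) tuple. A hedged sketch of the
// Python-side round trip (binding names from the method table below;
// StorageClass stands in for the storage type they are registered on):
//
//   fd, size = storage._share_fd_cpu_()            # producer
//   s = StorageClass._new_shared_fd_cpu(fd, size)  # consumer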
static PyObject* THPStorage_shareFd(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      reinterpret_cast<THPStorage*>(_self)->cdata->device_type() == at::kCPU,
      "_share_fd_: only available on CPU");
  auto self = (THPStorage*)_self;
  c10::StorageImpl* storage = self->cdata;
  at::MapAllocator* ctx = at::MapAllocator::fromDataPtr(storage->data_ptr());
  // If the storage is already in shared memory, we can just return a handle;
  // otherwise move it into shared memory first.
  if (!ctx) {
    at::Storage new_storage(THPStorage_newFdStorage(storage->nbytes()));
    at::Storage _self_aten = torch::createStorage(_self);
    {
      // Copying into shared memory can be slow, so release the GIL
      pybind11::gil_scoped_release no_gil;
      storage_copy(new_storage, _self_aten);
    }

    std::swap(*storage, *new_storage.unsafeGetStorageImpl());
    ctx = at::MapAllocator::fromDataPtr(storage->data_ptr());
    AT_ASSERT(ctx);
  }

  THPObjectPtr storage_handle(THPUtils_packInt32(ctx->fd()));
  if (!storage_handle)
    return nullptr;
  THPObjectPtr size(THPUtils_packUInt64(storage->nbytes() / sizeof(uint8_t)));
  if (!size)
    return nullptr;

  THPObjectPtr tuple(PyTuple_New(2));
  if (!tuple)
    return nullptr;
  PyTuple_SET_ITEM(tuple.get(), 0, storage_handle.release());
  PyTuple_SET_ITEM(tuple.get(), 1, size.release());
  return tuple.release();
  END_HANDLE_TH_ERRORS
}
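
// _new_shared_fd_cpu: attach to fd-backed shared memory. The incoming fd is
// dup()ed, so the caller retains ownership of the descriptor it passed in.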
static PyObject* THPStorage_newSharedFd(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected");
  PyObject* _tmp_fd = PyTuple_GET_ITEM(args, 0);
  PyObject* _size = PyTuple_GET_ITEM(args, 1);
  if (!THPUtils_checkLong(_tmp_fd) || !THPUtils_checkLong(_size)) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "_new_shared in file descriptor mode",
        1,
        "a file descriptor (int) and storage size (int)");
    return nullptr;
  }
  int tmp_fd = (int)THPUtils_unpackLong(_tmp_fd);
  int64_t size = THPUtils_unpackLong(_size);
  int fd = dup(tmp_fd);
  if (fd == -1) {
    THPUtils_setError("could not duplicate a shared memory file descriptor");
    return nullptr;
  }

  int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE |
      at::ALLOCATOR_MAPPED_KEEPFD | at::ALLOCATOR_MAPPED_FROMFD;
  return THPStorage_New(c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      size,
      at::MapAllocator::makeDataPtr(at::WITH_FD, "", fd, flags, size, nullptr),
      /*allocator=*/nullptr,
      /*resizable=*/false));
  END_HANDLE_TH_ERRORS
}
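
// _share_cuda_: export a CUDA storage to another process. Returns an 8-tuple
// (device, ipc_mem_handle, size_bytes, offset_bytes, ref_counter_handle,
// ref_counter_offset, event_handle, event_sync_required); the handle entries
// are only meaningful when the storage actually holds data. The ref counter
// file keeps the producer's allocation alive until every consumer has
// released it (see Note [CUDA IPC Refcounting implementation explained]).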
static PyObject* THPStorage_shareCuda(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
#ifdef USE_CUDA
  TORCH_CHECK(
      reinterpret_cast<THPStorage*>(_self)->cdata->device_type() == at::kCUDA,
      "_share_cuda_: only available on CUDA");
  auto self = (THPStorage*)_self;
  c10::StorageImpl* storage = self->cdata;

  if (storage->received_cuda()) {
    AT_ERROR(
        "Attempted to send CUDA tensor received from another process; this is not currently supported. Consider cloning before sending.");
  }

  at::DeviceGuard device_guard(storage->device());
  THPObjectPtr tuple(PyTuple_New(8));
  THPObjectPtr device(THPUtils_packInt32(storage->device().index()));
  THPObjectPtr _handle(Py_None);
  Py_INCREF(Py_None);
  THPObjectPtr size_bytes(THPUtils_packUInt64(storage->nbytes()));
  THPObjectPtr _offset_bytes(THPUtils_packInt32(0));
  THPObjectPtr _ref_counter(Py_None);
  Py_INCREF(Py_None);
  THPObjectPtr _ref_counter_offset(THPUtils_packInt32(0));
  THPObjectPtr _event_handle(Py_None);
  Py_INCREF(Py_None);
  THPObjectPtr _event_sync_required(Py_None);
  Py_INCREF(Py_None);
  if (storage->data<uint8_t>()) {
    size_t base_size = 0;
    void* base_ptr = c10::cuda::CUDACachingAllocator::getBaseAllocation(
        storage->data<uint8_t>(), &base_size);
    ptrdiff_t offset_bytes = (char*)storage->data<uint8_t>() - (char*)base_ptr;

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    cudaIpcMemHandle_t handle;
    C10_CUDA_CHECK(cudaIpcGetMemHandle(&handle, base_ptr));

    _handle = PyBytes_FromStringAndSize((char*)&handle, CUDA_IPC_HANDLE_SIZE);
    _offset_bytes = PyLong_FromSsize_t((Py_ssize_t)offset_bytes);

    // Put the storage data behind a new ref counting context
    // See Note [CUDA IPC Refcounting implementation explained]
    at::DataPtr sent_data_ptr =
        torch::GetNewRefCountedSentData(storage->data(), storage->device());
    auto old_data_ptr = storage->set_data_ptr(std::move(sent_data_ptr));
    auto sent_data =
        static_cast<torch::CudaIPCSentData*>(storage->data_ptr().get_context());
    sent_data->set_original_ptr(std::move(old_data_ptr));
    _ref_counter = PyBytes_FromString((sent_data->handle()).c_str());
    _ref_counter_offset = THPUtils_packInt64(sent_data->offset());

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    cudaIpcEventHandle_t ipc_event_handle;

    if (sent_data->event_sync_required_) {
      C10_CUDA_CHECK(
          cudaIpcGetEventHandle(&ipc_event_handle, sent_data->event_));
    }

    _event_handle = PyBytes_FromStringAndSize(
        (char*)&ipc_event_handle, CUDA_IPC_HANDLE_SIZE);
    _event_sync_required = PyBool_FromLong(sent_data->event_sync_required_);
  }

  if (!tuple || !device || !_handle || !size_bytes || !_offset_bytes ||
      !_event_handle) {
    return nullptr;
  }
  PyTuple_SET_ITEM(tuple.get(), 0, device.release());
  // cudaIpcMemHandle_t (of basePtr)
  PyTuple_SET_ITEM(tuple.get(), 1, _handle.release());
  // Size (in bytes) of the real storage; note this is not the size of the
  // basePtr memory block.
  PyTuple_SET_ITEM(tuple.get(), 2, size_bytes.release());
  // Offset (in bytes) of the real storage in the basePtr memory block.
  // NB: this offset MUST be in bytes instead of numel, since we use
  // (storage_handle, offset) as the key in shared_cache
  // (multiprocessing/reduction.py). An offset in numel cannot uniquely
  // represent a storage.
  PyTuple_SET_ITEM(tuple.get(), 3, _offset_bytes.release());
  PyTuple_SET_ITEM(tuple.get(), 4, _ref_counter.release());
  PyTuple_SET_ITEM(tuple.get(), 5, _ref_counter_offset.release());
  PyTuple_SET_ITEM(tuple.get(), 6, _event_handle.release());
  PyTuple_SET_ITEM(tuple.get(), 7, _event_sync_required.release());
  return tuple.release();
#else
  TORCH_CHECK(false, "CUDA is not available");
#endif
  END_HANDLE_TH_ERRORS
}
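
// _release_ipc_counter_cuda: decrement the producer's shared ref counter
// directly, without materializing the storage itself. Deletion is best
// effort; see the comment inside the function.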
static PyObject* THPStorage_releaseIPCCounter(
    PyObject* _unused,
    PyObject* args) {
  HANDLE_TH_ERRORS
#ifdef USE_CUDA
  THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected");
  PyObject* _ref_counter = PyTuple_GET_ITEM(args, 0);
  PyObject* _ref_counter_offset = PyTuple_GET_ITEM(args, 1);
  if (!(PyBytes_Check(_ref_counter) &&
        THPUtils_checkLong(_ref_counter_offset))) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "_release_ipc_counter in CUDA mode",
        1,
        "(bytes _ref_counter, int _ref_counter_offset)");
    return nullptr;
  }
  std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
  ptrdiff_t ref_counter_offset =
      (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);
  // We don't want to break existing code, so resource deletion is done on a
  // best-effort basis. An exception is expected if the producer process
  // terminated before the consumer released the data.
  int flags = at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE;
  try {
    auto sptr = at::RefcountedMapAllocator::makeDataPtr(
        ref_counter_handle.c_str(),
        flags,
        sizeof(int64_t) * torch::CUDA_IPC_REF_COUNTER_FILE_SIZE,
        nullptr);
    *(static_cast<int64_t*>(sptr.get()) + ref_counter_offset) -= 1;
  } catch (const c10::Error&) {
    // Already warned inside of producer process
  }
  Py_RETURN_NONE;
#else
  TORCH_CHECK(false, "CUDA is not available");
#endif
  END_HANDLE_TH_ERRORS
}
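
// Helper: copy a Python bytes object holding a CUDA IPC handle into a
// std::string, validating that it has exactly CUDA_IPC_HANDLE_SIZE bytes.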
#ifdef USE_CUDA
static std::string THPStorage_bytesAsHandleString(PyObject* handle) {
  char* buffer = nullptr;
  Py_ssize_t handle_size = 0;
  if (PyBytes_AsStringAndSize(handle, &buffer, &handle_size) == -1) {
    // PyBytes_AsStringAndSize has already set a Python exception; propagate
    // it rather than constructing a std::string from a null pointer, which
    // is undefined behavior.
    throw python_error();
  }
  TORCH_CHECK(handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
  return std::string(buffer, handle_size);
}
#endif
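
// _new_shared_cuda: consumer side of _share_cuda_. Opens the producer's IPC
// memory handle, offsets to the real storage, waits on the producer's IPC
// event if required, and installs a deleter that decrements the producer's
// ref counter once this process drops the storage.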
static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
#ifdef USE_CUDA
  THPUtils_assert(PyTuple_GET_SIZE(args) == 8, "tuple of 8 items expected");
  PyObject* _device = PyTuple_GET_ITEM(args, 0);
  PyObject* _handle = PyTuple_GET_ITEM(args, 1);
  PyObject* _size_bytes = PyTuple_GET_ITEM(args, 2);
  PyObject* _offset_bytes = PyTuple_GET_ITEM(args, 3);
  PyObject* _ref_counter = PyTuple_GET_ITEM(args, 4);
  PyObject* _ref_counter_offset = PyTuple_GET_ITEM(args, 5);
  PyObject* _event_handle = PyTuple_GET_ITEM(args, 6);
  PyObject* _event_sync_required = PyTuple_GET_ITEM(args, 7);
  if (!(THPUtils_checkLong(_device) && THPUtils_checkLong(_size_bytes) &&
        PyBytes_Check(_handle) && PyBytes_Check(_ref_counter) &&
        PyBytes_Check(_event_handle) && THPUtils_checkLong(_offset_bytes) &&
        THPUtils_checkLong(_ref_counter_offset) &&
        PyBool_Check(_event_sync_required))) {
    THPUtils_invalidArguments(
        args,
        nullptr,
        "_new_shared in CUDA mode",
        1,
        "(int device, bytes handle, int storage_size_bytes, int storage_offset_bytes, bytes _ref_counter, int _ref_counter_offset, bytes event_handle, bool event_sync_required)");
    return nullptr;
  }

  size_t storage_size =
      (size_t)THPUtils_unpackLong(_size_bytes) / sizeof(uint8_t);
  ptrdiff_t storage_offset_bytes =
      (ptrdiff_t)THPUtils_unpackLong(_offset_bytes);

  int64_t device = THPUtils_unpackLong(_device);
  at::cuda::CUDAGuard device_guard(device);

  if (PyObject_IsTrue(_event_sync_required)) {
    // Ensure that the producer has prepared all of the tensor's data
    std::string s_ipc_event_handle =
        THPStorage_bytesAsHandleString(_event_handle);
    auto ipc_event_handle = reinterpret_cast<const cudaIpcEventHandle_t*>(
        s_ipc_event_handle.c_str());
    cudaEvent_t event = nullptr;
    C10_CUDA_CHECK(cudaIpcOpenEventHandle(&event, *ipc_event_handle));
    C10_CUDA_CHECK(
        cudaStreamWaitEvent(c10::cuda::getCurrentCUDAStream(device), event, 0));
  }

  std::string s_handle = THPStorage_bytesAsHandleString(_handle);
  std::shared_ptr<void> basePtr =
      c10::cuda::CUDACachingAllocator::getIpcDevPtr(s_handle);

  // Offset the basePtr to reconstruct the real storage
  // devPtr = basePtr + storage_offset
  void* devPtr = basePtr.get();
  devPtr = (char*)devPtr + storage_offset_bytes;

  std::string ref_counter_handle = PyBytes_AS_STRING(_ref_counter);
  ptrdiff_t ref_counter_offset =
      (ptrdiff_t)THPUtils_unpackLong(_ref_counter_offset);

  struct IpcDeleterContext {
    std::string ref_counter_handle;
    ptrdiff_t ref_counter_offset;
    int64_t device;
    torch::CudaIPCReceivedData received_data;
  };

  auto ctx = std::make_unique<IpcDeleterContext>();
  ctx->ref_counter_handle = std::move(ref_counter_handle);
  ctx->ref_counter_offset = ref_counter_offset;
  ctx->device = device;
  ctx->received_data.shared_ptr_ = std::move(basePtr);

  auto cur_device = at::cuda::current_device();
  c10::DataPtr data_ptr(
      devPtr,
      ctx.release(),
      +[](void* ctx_) {
        std::unique_ptr<IpcDeleterContext> ctx(
            static_cast<IpcDeleterContext*>(ctx_));
        ctx->received_data.shared_ptr_.reset();

        // Sync the default stream to make sure all operations related to the
        // storage have finished (otherwise another process may reuse the
        // memory and corrupt its data).

        // Ideally all shared-memory reference counting could be replaced by
        // sending an untriggered CUDA event from the producer to the consumer
        // and using that event as the criterion for memory release. However,
        // CUDA (as of 10.1) does not support the creation of untriggered
        // events, and the performance impact of having thousands of shared
        // events is unknown.

        // TODO: Instead of cudaStreamSynchronize it is possible to add a
        // stream callback and release the counter inside of it (need to check
        // the performance impact).
        at::cuda::stream_synchronize(
            c10::cuda::getCurrentCUDAStream(ctx->device));

        // We don't want to break existing code, so resource deletion is done
        // on a best-effort basis. An exception is expected if the producer
        // process terminated before the consumer released the data.
        int flags =
            at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_NOCREATE;
        try {
          auto sptr = at::RefcountedMapAllocator::makeDataPtr(
              ctx->ref_counter_handle.c_str(),
              flags,
              sizeof(int64_t) * torch::CUDA_IPC_REF_COUNTER_FILE_SIZE,
              nullptr);
          *(static_cast<int64_t*>(sptr.get()) + ctx->ref_counter_offset) -= 1;
        } catch (const c10::Error&) {
          // Already warned inside of producer process
        }
      },
      at::Device(at::DeviceType::CUDA, cur_device));

  auto base = c10::make_intrusive<at::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(),
      storage_size,
      std::move(data_ptr),
      /*allocator=*/nullptr,
      /*resizable=*/false);

  base->set_resizable(false);
  base->set_received_cuda(true);

  return THPStorage_New(std::move(base));
#else
  TORCH_CHECK(false, "CUDA is not available");
#endif
  END_HANDLE_TH_ERRORS
}
// Returns an object that holds a "weak" pointer to the c10::StorageImpl. This
// pointer keeps the c10::StorageImpl struct live, but does not retain the
// data pointer.
//
// NB: This does NOT preserve object identity when you call it multiple times
static PyObject* THPStorage_weakRef(PyObject* _self, PyObject* args) {
  HANDLE_TH_ERRORS
  auto self = (THPStorage*)_self;
  c10::StorageImpl* storage = self->cdata;
  return PyLong_FromVoidPtr(c10::raw::intrusive_ptr::make_weak(storage));
  END_HANDLE_TH_ERRORS
}
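
// _new_with_weak_ptr / _free_weak_ref / _expired operate on the raw weak
// pointer produced by _weak_ref above: lock it back into a storage (or
// return None if the storage died), release the weak reference, or test
// whether the underlying storage has expired.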
PyObject* THPStorage_newWithWeakPtr(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPUtils_checkLong(arg), "_new_with_weak_ptr(): arg must be an 'int'");
  c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
  if (auto* storage = c10::raw::weak_intrusive_ptr::lock(weak_storage)) {
    return THPStorage_New(
        c10::intrusive_ptr<c10::StorageImpl>::reclaim(storage));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPStorage_freeWeakRef(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  if (arg == Py_None) {
    Py_RETURN_NONE;
  }
  THPUtils_assert(
      THPUtils_checkLong(arg), "_free_weak_ref(): arg must be an 'int'");
  c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
  c10::raw::weak_intrusive_ptr::decref(weak_storage);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPStorage_expired(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  THPUtils_assert(THPUtils_checkLong(arg), "_expired(): arg must be an 'int'");
  c10::StorageImpl* weak_storage = (c10::StorageImpl*)PyLong_AsVoidPtr(arg);
  return PyBool_FromLong(
      c10::raw::weak_intrusive_ptr::use_count(weak_storage) == 0);
  END_HANDLE_TH_ERRORS
}
PyObject* THPStorage_sharedFd(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto self = (THPStorage*)_self;
  at::MapAllocator* ctx = nullptr;
  if (self->cdata->device_type() == at::kCPU) {
    c10::StorageImpl* storage = self->cdata;
    ctx = at::MapAllocator::fromDataPtr(storage->data_ptr());
  }

  THPUtils_assert(ctx, "couldn't retrieve a shared file descriptor");
  return THPUtils_packInt32(ctx->fd());
  END_HANDLE_TH_ERRORS
}
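
// A storage counts as shared if it is CUDA-backed (CUDA storages can always
// be shared via IPC) or if its CPU data lives in one of the shared-memory
// mappings above.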
PyObject* THPStorage_isShared(PyObject* _self, PyObject* noargs) {
  auto self = (THPStorage*)_self;
  if (self->cdata->device_type() == at::kCUDA) {
    Py_RETURN_TRUE;
  }
  if (at::MapAllocator::fromDataPtr(self->cdata->data_ptr()) ||
      THManagedMapAllocator::fromDataPtr(self->cdata->data_ptr())) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}
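
// Python-visible bindings. Entries without METH_STATIC / METH_CLASS are
// instance methods on the storage type.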
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPStorage_sharingMethods[] = {
    {"_new_with_weak_ptr",
     THPStorage_newWithWeakPtr,
     METH_O | METH_CLASS,
     nullptr},
    {"_share_cuda_", THPStorage_shareCuda, METH_NOARGS, nullptr},
    {"_new_shared_cuda",
     THPStorage_newSharedCuda,
     METH_VARARGS | METH_STATIC,
     nullptr},
    {"_release_ipc_counter_cuda",
     THPStorage_releaseIPCCounter,
     METH_VARARGS | METH_STATIC,
     nullptr},
    {"_share_fd_cpu_", THPStorage_shareFd, METH_NOARGS, nullptr},
    {"_new_shared_fd_cpu",
     THPStorage_newSharedFd,
     METH_VARARGS | METH_STATIC,
     nullptr},
    {"_new_using_fd_cpu",
     THPStorage_pyNewFdStorage,
     METH_VARARGS | METH_STATIC,
     nullptr},
    {"_share_filename_cpu_", THPStorage_shareFilename, METH_NOARGS, nullptr},
    {"_new_shared_filename_cpu",
     THPStorage_newSharedFilename,
     METH_VARARGS | METH_STATIC,
     nullptr},
    {"_new_using_filename_cpu",
     THPStorage_pyNewFilenameStorage,
     METH_VARARGS | METH_STATIC,
     nullptr},
    {"_weak_ref", THPStorage_weakRef, METH_NOARGS, nullptr},
    {"_free_weak_ref", THPStorage_freeWeakRef, METH_O | METH_STATIC, nullptr},
    {"_expired", THPStorage_expired, METH_O | METH_STATIC, nullptr},
    {"_shared_decref", THPStorage_sharedDecref, METH_NOARGS, nullptr},
    {"_shared_incref", THPStorage_sharedIncref, METH_NOARGS, nullptr},
    {"_get_shared_fd", THPStorage_sharedFd, METH_NOARGS, nullptr},
    {"is_shared", THPStorage_isShared, METH_NOARGS, nullptr},
    {nullptr}};

PyMethodDef* THPStorage_getSharingMethods() {
  return THPStorage_sharingMethods;
}