#include <ATen/MapAllocator.h>
#include <torch/csrc/CudaIPCTypes.h>
#include <atomic>
#include <map>
#include <mutex>
#include <random>
#include <string>
#include <vector>

namespace torch {

namespace {

void warnProducerTerminatedBeforeSharedTensorsReleased() {
  static bool warned = false;
  if (!warned) {
    LOG(WARNING)
        << "Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]";
    warned = true;
  }
}

struct CudaIPCGlobalEntities {
  // This class is used as a singleton (see cuda_ipc_global_entities).
  // This flag tracks its lifetime so that it is not accessed after it has
  // been destroyed, which would lead to segmentation faults. Note that a
  // trivial type is used for it, so it does not suffer from construction
  // and destruction order issues.
  static bool alive;

  std::mutex ref_counters_mutex_;
  std::atomic<int64_t> sync_events_used_{0};
  std::map<std::string, std::shared_ptr<CudaIPCRefCountersFile>>
      ref_counters_files_;
  std::shared_ptr<CudaIPCRefCountersFile> next_available_ref_counters_file_;
  CudaIPCSentDataLimbo CudaIPCSentDataLimbo_;
  CudaIPCGlobalEntities() {
    alive = true;
  }
  ~CudaIPCGlobalEntities() {
    CudaIPCSentDataLimbo_.collect();
    safe_clean_current_file();
    if (next_available_ref_counters_file_) {
      warnProducerTerminatedBeforeSharedTensorsReleased();
    }
    alive = false;
  }
  void safe_clean_current_file() {
    std::lock_guard<std::mutex> lock(ref_counters_mutex_);
    if (next_available_ref_counters_file_ &&
        next_available_ref_counters_file_->offsets_in_use() == 0) {
      ref_counters_files_.erase(next_available_ref_counters_file_->handle());
      next_available_ref_counters_file_.reset();
    }
  }
};

bool CudaIPCGlobalEntities::alive = false;
CudaIPCGlobalEntities cuda_ipc_global_entities;

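// CudaIPCSentDataLimbo holds sent blocks that consumer processes still
// reference (counter_value() > 0) at the moment the producer tried to free
// them. collect() scans the limbo and releases every block whose reference
// count has since dropped to zero.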
CudaIPCSentDataLimbo::~CudaIPCSentDataLimbo() {
  collect();
  if (size() > 0) {
    warnProducerTerminatedBeforeSharedTensorsReleased();
  }
}

bool CudaIPCSentDataLimbo::collect() {
  bool freed_memory = false;
  std::vector<std::unique_ptr<CudaIPCSentData>> reset_blocks;
  { // Begin critical section to modify shared blocks
    std::lock_guard<std::mutex> lock(limbo_mutex_);
    std::vector<std::unique_ptr<CudaIPCSentData>> kept_blocks;
    for (auto& sd : shared_blocks_) {
      if (sd->counter_value() > 0) {
        kept_blocks.push_back(std::move(sd));
      } else {
        freed_memory = true;
        reset_blocks.push_back(std::move(sd));
      }
    }
    shared_blocks_ = std::move(kept_blocks);
  }
  // Need to reset blocks out of the critical section here, otherwise it
  // deadlocks.
  for (auto& sd : reset_blocks) {
    sd.reset();
  }
  return freed_memory;
}

void CudaIPCSentDataLimbo::add(std::unique_ptr<CudaIPCSentData> shared_block) {
  std::lock_guard<std::mutex> lock(limbo_mutex_);
  static bool warned = false;
  if (shared_blocks_.size() > CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO &&
      !warned) {
    LOG(WARNING)
        << "Producer process tried to deallocate over "
        << CUDA_IPC_WARN_AFTER_X_BLOCKS_IN_LIMBO
        << " memory blocks referred to by consumer processes. Deallocation might be significantly slowed down. "
        << "We assume this will never happen, but if it does, please file a bug report at https://github.com/pytorch/pytorch";
    warned = true;
  }
  shared_blocks_.push_back(std::move(shared_block));
}

uint64_t CudaIPCSentDataLimbo::size() {
  std::lock_guard<std::mutex> lock(limbo_mutex_);
  return shared_blocks_.size();
}

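// Deleter installed on the DataPtr returned by GetNewRefCountedSentData.
// If consumers still hold references to the block, ownership is moved into
// the limbo instead of freeing the memory immediately; otherwise the block
// is released as soon as sent_data goes out of scope.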
void CudaIPCSentDataDelete(void* ptr) {
  std::unique_ptr<CudaIPCSentData> sent_data(
      static_cast<CudaIPCSentData*>(ptr));
  if (!CudaIPCGlobalEntities::alive) {
    return;
  }
  if (sent_data->counter_value() > 0) {
    cuda_ipc_global_entities.CudaIPCSentDataLimbo_.add(std::move(sent_data));
  }
  cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect();
}

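// Gives an offset back to its ref counter file and drops the file from the
// global map once no offsets are in use and none are left to hand out.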
void ReturnRefCounter(const std::string& handle, uint64_t offset) {
  if (!CudaIPCGlobalEntities::alive) {
    return;
  }
  std::lock_guard<std::mutex> lock(
      cuda_ipc_global_entities.ref_counters_mutex_);
  auto& map = cuda_ipc_global_entities.ref_counters_files_;
  auto it = map.find(handle);
  if (it != map.end()) {
    it->second->return_offset(offset);
    if (it->second->offsets_in_use() == 0 && !it->second->have_offsets()) {
      map.erase(handle);
    }
  }
}

} // namespace

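// CudaIPCSentData ties a sent CUDA allocation to one slot (handle + offset)
// in a shared-memory ref counter file and, when possible, to a CUDA
// interprocess event used to synchronize with the consumer.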
CudaIPCSentData::CudaIPCSentData(
    std::string handle,
    int64_t offset,
    int64_t* counter_ptr,
    at::Device device)
    : handle_(std::move(handle)),
      offset_(offset),
      counter_ptr_(counter_ptr),
      original_ptr_(),
      device_(device) {
#if !defined(USE_ROCM)
  // CUDA has an unofficial limit on the number of recorded blocking
  // interprocess events. To avoid exhausting all of them, we switch to
  // stream synchronization before the limit is reached. For example, the
  // following snippet runs into that limit:
  //
  // ```python
  // import torch
  // a = [ torch.cuda.Event(
  //     enable_timing=False, blocking=True, interprocess=True) for i in
  //     range(30000) ]
  // [i.record() for i in a]
  // ```
  //
  if (cuda_ipc_global_entities.sync_events_used_.load() <
      CUDA_IPC_MAXIMUM_EVENTS_TO_USE) {
    // TODO: More efficient would be to create event inside of main thread (at
    // the moment of the queue.put). The reason this is more efficient is
    // because the main thread may have queued extra work on the stream, which
    // this event will consequently wait for (uselessly).
    cuda_ipc_global_entities.sync_events_used_++;
    C10_CUDA_CHECK(cudaEventCreateWithFlags(
        &event_,
        cudaEventDisableTiming | cudaEventInterprocess |
            cudaEventBlockingSync));
    C10_CUDA_CHECK(cudaEventRecord(
        event_, c10::cuda::getCurrentCUDAStream(device.index())));
    event_sync_required_ = true;
  } else {
    auto stream = c10::cuda::getCurrentCUDAStream(device.index());
    at::cuda::stream_synchronize(stream);
    event_ = nullptr;
    event_sync_required_ = false;
  }
#else
  // cuIpcGetEventHandle with HIP is not supported, so we have to sync
  // stream instead of passing event
  auto stream = c10::cuda::getCurrentCUDAStream(device.index());
  at::cuda::stream_synchronize(stream);
  event_sync_required_ = false;
#endif
}

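// Returns the ref counter slot and destroys the interprocess event (if one
// was created). Errors during event destruction are swallowed because
// destructors must not throw.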
CudaIPCSentData::~CudaIPCSentData() {
  ReturnRefCounter(handle_, offset_);
#if !defined(USE_ROCM)
  try {
    if (event_sync_required_) {
      at::cuda::CUDAGuard device_guard(device_.index());
      C10_CUDA_CHECK(cudaEventDestroy(event_));
      if (!CudaIPCGlobalEntities::alive) {
        return;
      }
      cuda_ipc_global_entities.sync_events_used_--;
    }
  } catch (...) { /* No throw */
  }
#endif
}

int64_t CudaIPCSentData::counter_value() {
  return *counter_ptr_;
}

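// Wraps `data` in a DataPtr whose deleter is CudaIPCSentDataDelete. Lazily
// creates a shared-memory ref counter file when none is available, claims
// the next free offset in it for this allocation, and retires the file once
// all of its offsets have been handed out.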
at::DataPtr GetNewRefCountedSentData(void* data, at::Device device) {
  {
    std::lock_guard<std::mutex> lock(
        cuda_ipc_global_entities.ref_counters_mutex_);
    if (!cuda_ipc_global_entities.next_available_ref_counters_file_) {
      std::string ref_counter_handle = at::NewProcessWideShmHandle();

      int flags =
          at::ALLOCATOR_MAPPED_SHAREDMEM | at::ALLOCATOR_MAPPED_EXCLUSIVE;
      at::DataPtr sptr = at::RefcountedMapAllocator::makeDataPtr(
          ref_counter_handle.c_str(),
          flags,
          sizeof(int64_t) * CUDA_IPC_REF_COUNTER_FILE_SIZE,
          nullptr);
      auto rc = std::make_shared<CudaIPCRefCountersFile>(
          ref_counter_handle, CUDA_IPC_REF_COUNTER_FILE_SIZE, std::move(sptr));
      cuda_ipc_global_entities.ref_counters_files_[ref_counter_handle] = rc;
      cuda_ipc_global_entities.next_available_ref_counters_file_ = rc;
    }
  }
  cuda_ipc_global_entities.next_available_ref_counters_file_->set_counter(1);
  auto sent_data = new CudaIPCSentData(
      cuda_ipc_global_entities.next_available_ref_counters_file_->handle(),
      cuda_ipc_global_entities.next_available_ref_counters_file_->get_offset(),
      cuda_ipc_global_entities.next_available_ref_counters_file_->counter_ptr(),
      device);

  cuda_ipc_global_entities.next_available_ref_counters_file_->rotate_offset();
  if (!cuda_ipc_global_entities.next_available_ref_counters_file_
           ->have_offsets()) {
    cuda_ipc_global_entities.next_available_ref_counters_file_.reset();
  }
  return at::DataPtr(data, sent_data, CudaIPCSentDataDelete, device);
}

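// Collects limbo blocks whose consumers have released them and returns true
// if any memory was freed. This is the entry point used by
// torch.cuda.ipc_collect().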
bool CudaIPCCollect() {
  if (!CudaIPCGlobalEntities::alive) {
    return true;
  }
  bool freed_memory = cuda_ipc_global_entities.CudaIPCSentDataLimbo_.collect();
  if (cuda_ipc_global_entities.CudaIPCSentDataLimbo_.size() == 0) {
    cuda_ipc_global_entities.safe_clean_current_file();
  }
  return freed_memory;
}

} // namespace torch

namespace c10 {
namespace {
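// Register CudaIPCCollect as a free-memory callback so the CUDA caching
// allocator can try to reclaim limbo blocks when it is under memory pressure.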
REGISTER_FREE_MEMORY_CALLBACK("cuda_ipc_collect", CudaIPCCollectCallback);
} // namespace
} // namespace c10