1 | #include <torch/csrc/distributed/rpc/rref_context.h> |
2 | #include <torch/csrc/distributed/rpc/rref_proto.h> |
3 | #include <torch/csrc/distributed/rpc/utils.h> |
4 | |
5 | #include <sstream> |
6 | |
7 | namespace torch { |
8 | namespace distributed { |
9 | namespace rpc { |
10 | |
// Thread-local table of UserRRefs created while recording_ is true on this
// thread; populated by addPendingUser() and consumed by
// waitForThreadLocalPendingRRefs().
thread_local std::vector<std::shared_ptr<RRefContext::PendingUserState>>
    RRefContext::userTable_;
// Whether this thread is currently recording newly created pending UserRRefs
// (toggled by recordThreadLocalPendingRRefs / waitForThreadLocalPendingRRefs).
thread_local bool RRefContext::recording_ = false;
14 | |
15 | namespace callback { |
16 | void confirmPendingUser( |
17 | const JitFuture& jitFuture, |
18 | const ForkId& expectedForkId) { |
19 | if (!jitFuture.hasError()) { |
20 | auto msgPtr = jitFuture.constValue().toCustomClass<Message>(); |
21 | auto msgType = msgPtr->type(); |
22 | auto rpc = deserializeResponse(*msgPtr, msgType); |
23 | auto& rr = dynamic_cast<RemoteRet&>(*rpc); |
24 | TORCH_INTERNAL_ASSERT(rr.forkId() == expectedForkId); |
25 | } else { |
26 | // Handle errors, such as timeouts, by invoking the error handler on the |
27 | // rref. |
28 | // Note [Best Effort Error handling for Remote calls]: |
29 | // When remote calls initiated by rpc.remote() fail, such as with a timeout |
30 | // error, we take a best-effort approach to error handling. We handle errors |
31 | // when callbacks corresponding to the remote call run, and set the error |
32 | // information on the RRef. If the RRef has not been used by the application |
33 | // before this process (such as to_here or fork call), then future uses of |
34 | // the RRef will appropriately raise errors. However, it is possible that |
35 | // the user application will use the RRef before the errors are handled. In |
36 | // this case, errors may not be raised as they have not yet been handled. |
37 | auto rref_ptr = RRefContext::getInstance().getPendingUser(expectedForkId); |
38 | auto errorType = getRPCErrorType(jitFuture); |
39 | rref_ptr->handleError(errorType, jitFuture); |
40 | } |
41 | RRefContext::getInstance().delPendingUser(expectedForkId); |
42 | } |
43 | |
44 | c10::intrusive_ptr<RRef> finishCreatingOwnerRRef( |
45 | const JitFuture& jitFuture, |
46 | const RRefId& rrefId) { |
47 | if (jitFuture.hasError()) { |
48 | auto& ctx = RRefContext::getInstance(); |
49 | // We expect to run this callback only after the OwnerRRef has been created, |
50 | // since this is only invoked when sending to self. |
51 | auto rref_ptr = |
52 | fromRRefInterface(ctx.getOwnerRRef(rrefId, /* foreCreated */ true) |
53 | ->constValue() |
54 | .toRRef()); |
55 | auto errorType = getRPCErrorType(jitFuture); |
56 | rref_ptr->handleError(errorType, jitFuture); |
57 | // OwnerRRefs do not have a forkId, so don't need to assert here. |
58 | auto deletedRRef = |
59 | ctx.delForkOfOwner(rref_ptr->rrefId(), rref_ptr->rrefId()); |
60 | return deletedRRef; |
61 | } else { |
62 | auto msgPtr = jitFuture.constValue().toCustomClass<Message>(); |
63 | auto msgType = msgPtr->type(); |
64 | auto rpc = deserializeResponse(*msgPtr, msgType); |
65 | auto& rr = dynamic_cast<RemoteRet&>(*rpc); |
66 | TORCH_INTERNAL_ASSERT( |
67 | rr.rrefId() == rr.forkId(), |
68 | "Expecting an OwnerRRef as RemoteRet but got a fork." ); |
69 | auto& ctx = RRefContext::getInstance(); |
70 | auto deletedRRef = ctx.delForkOfOwner(rr.rrefId(), rr.rrefId()); |
71 | return deletedRRef; |
72 | } |
73 | } |
74 | |
75 | } // namespace callback |
76 | |
// Keys for the RRef-related debug information map returned by getDebugInfo().
const std::string kNumOwnerRRefs = "num_owner_rrefs"; // size of owners_
const std::string kNumPendingFutures =
    "num_pending_futures"; // in-flight sendWithRetries futures
const std::string kNumPendingUsers = "num_pending_users"; // size of pendingUsers_
const std::string kNumForks = "num_forks"; // total forks across all owned RRefs
82 | |
83 | RRefContext& RRefContext::getInstance() { |
84 | // Leaky singleton to avoid module destructor races. |
85 | static RRefContext* context = new RRefContext(RpcAgent::getCurrentRpcAgent()); |
86 | return *context; |
87 | } |
88 | |
// Shuts down the singleton's state: marks the context destroyed (which stops
// delUser() from sending further delete messages), checks for leaked RRefs,
// and clears the owner maps. Returns the OwnerRRefs that hold Python objects
// so the caller can dispose of them appropriately.
std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
    bool ignoreRRefLeak) {
  auto& ctx = RRefContext::getInstance();
  {
    std::lock_guard<std::mutex> lock(ctx.destroyedMutex_);
    ctx.destroyed_ = true;
  }
  ctx.checkRRefLeaks(ignoreRRefLeak);
  std::vector<c10::intrusive_ptr<RRef>> deletedRRefs;
  // NOTE(review): owners_ is iterated and cleared here without holding
  // mutex_; presumably safe because this runs during shutdown when no other
  // threads touch the context -- confirm.
  for (auto& entry : ctx.owners_) {
    auto rref = entry.second;
    if (rref->isPyObj()) {
      deletedRRefs.emplace_back(std::move(rref));
    }
  }
  ctx.owners_.clear();
  ctx.pendingOwners_.clear();
  return deletedRRefs;
}
108 | |
109 | void RRefContext::handleException(const JitFuture& jitFuture) { |
110 | if (jitFuture.hasError()) { |
111 | auto errMsg = jitFuture.tryRetrieveErrorMessage(); |
112 | VLOG(1) << "Got exception: " << errMsg; |
113 | TORCH_CHECK(false, errMsg); |
114 | } |
115 | } |
116 | |
117 | void RRefContext::handleExceptionSilent(const JitFuture& jitFuture) { |
118 | if (jitFuture.hasError()) { |
119 | auto errMsg = jitFuture.tryRetrieveErrorMessage(); |
120 | VLOG(1) << "Got exception: " << errMsg; |
121 | TORCH_CHECK_MSG(false, errMsg); |
122 | } |
123 | } |
124 | |
// Constructs the context around the given RpcAgent, which is used for all
// outgoing RRef control messages (fork requests, child accepts, deletes).
RRefContext::RRefContext(std::shared_ptr<RpcAgent> agent)
    : agent_(std::move(agent)) {}
127 | |
RRefContext::~RRefContext() {
  // Destruction with live OwnerRRefs indicates a shutdown-ordering bug:
  // destroyInstance() should have emptied owners_ before this runs.
  if (!owners_.empty()) {
    VLOG(1) << "Destructing RRefContext with non-empty OwnerRRef set. "
            << "This would likely cause Python deref error. "
            << "Make sure destroyInstance() is invoked before destruction.";
  }
}
135 | |
136 | std::unordered_map<std::string, std::string> RRefContext::getDebugInfo() { |
137 | std::unordered_map<std::string, std::string> info; |
138 | std::unique_lock<std::mutex> lock(mutex_); |
139 | auto ownerSize = owners_.size(); |
140 | auto numPendingUsers = pendingUsers_.size(); |
141 | int numForks = 0; |
142 | for (const auto& owner : forks_) { |
143 | numForks += owner.second.size(); |
144 | } |
145 | lock.unlock(); |
146 | info[kNumOwnerRRefs] = c10::to_string(ownerSize); |
147 | info[kNumPendingFutures] = c10::to_string(numPendingFutures_.load()); |
148 | info[kNumPendingUsers] = c10::to_string(numPendingUsers); |
149 | info[kNumForks] = c10::to_string(numForks); |
150 | return info; |
151 | } |
152 | |
153 | void RRefContext::checkRRefLeaks(bool ignoreRRefLeak) { |
154 | if (!forks_.empty()) { |
155 | std::stringstream ss; |
156 | for (auto& entry : forks_) { |
157 | const RRefId& rrefId = entry.first; |
158 | for (const auto& forkId : entry.second) { |
159 | ss << "Leaking RRef " << rrefId << " with fork Id " << forkId |
160 | << std::endl; |
161 | } |
162 | } |
163 | |
164 | LOG(WARNING) |
165 | << "Detected RRef Leaks during shutdown. This usually " |
166 | << "occurs when the application code still holds references to RRef " |
167 | << "instances when calling shutdown(). If the program has " |
168 | << "completed correctly and the process is exiting, it is OK to " |
169 | << "ignore these leaks. However, if you program will keep running " |
170 | << "after this, these leaks could result in memory leaks on RRef " |
171 | << "owners. Please make sure all RRefs are out of scope and Python " |
172 | << "GC has deleted them before calling shutdown(): \n" |
173 | << ss.str(); |
174 | if (!ignoreRRefLeak) { |
175 | TORCH_CHECK(false, ss.str()); |
176 | } |
177 | } |
178 | } |
179 | |
180 | c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef( |
181 | worker_id_t ownerId, |
182 | const TypePtr& type) { |
183 | TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner." ); |
184 | // Explicitly creating rrefId before forkId to make sure the order is |
185 | // deterministic, as the argument evaluation order is system and compiler |
186 | // dependent. |
187 | const auto rrefId = genGloballyUniqueId(); |
188 | const auto forkId = genGloballyUniqueId(); |
189 | return createUserRRef(ownerId, rrefId, forkId, type); |
190 | } |
191 | |
// Creates a UserRRef with the given identifiers. The new UserRRef is NOT
// tracked by this context and is NOT yet confirmed by the owner.
c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    const TypePtr& type) {
  TORCH_CHECK(ownerId != getWorkerId(), "RRef owner cannot create user RRef.");
  // RRefContext does not track user RRefs; the UserRRef is destructed when
  // the last reference to it goes away.
  //
  // NB: cannot use a make_shared-style helper here as the constructor of
  // UserRRef is private.
  // NB: This UserRRef has not been confirmed by the owner yet. This function's
  // call site is responsible for adding this UserRRef to pendingUsers_.
  // Currently, there are two call sites.
  // (1) The creator user in python_functions.cpp
  // (2) The callee user in RRefContext::notifyOwnerAndParentOfFork.
  //
  // The reason for not adding the pending user here is to put addPendingUser()
  // close to where the RPC occurs, and it is more clear to pair it with
  // deletePendingUser() in the response callback at the call site.
  return c10::make_intrusive<UserRRef>(ownerId, rrefId, forkId, type);
}
213 | |
// Deletes a local UserRRef: notifies the owner to drop this fork (unless the
// context is already shut down), then removes the local confirmed-user entry.
void RRefContext::delUser(
    const worker_id_t owner,
    const RRefId& rrefId,
    const ForkId& forkId) {
  {
    std::lock_guard<std::mutex> lock(destroyedMutex_);
    // After destroyInstance() has run, skip the network message entirely.
    if (!destroyed_) {
      // Sending an RRefUserDelete causes the receiver to run delForkOfOwner,
      // which is now idempotent. See the comment at RRefContext::delForkOfOwner
      // for more details.
      ++numPendingFutures_;
      auto jitFuture = agent_->sendWithRetries(
          agent_->getWorkerInfo(owner),
          RRefUserDelete(rrefId, forkId).toMessage());

      // Errors on the delete message are logged, not raised (best effort).
      jitFuture->addCallback([this](JitFuture& future) {
        handleExceptionSilent(future);
        --numPendingFutures_;
      });
    }
  }

  // Note: taken after releasing destroyedMutex_ to keep lock scopes disjoint.
  std::lock_guard<std::mutex> lock(mutex_);
  confirmedUsers_.erase(forkId);
}
239 | |
// Graceful-shutdown helper. In order: (1) wait for all pending UserRRefs to
// be confirmed, (2) proactively delete every confirmed UserRRef, (3) remove
// OwnerRRefs that were never forked, and (4) wait for remaining OwnerRRefs to
// be deleted by incoming RRefUserDelete messages. Each wait is bounded by
// timeoutMillis and logs an error on timeout rather than raising.
void RRefContext::delAllUsersAndUnforkedOwners(
    std::chrono::milliseconds timeoutMillis) {
  // First, wait for all pending UserRRefs to be confirmed:
  // one kind is pendingUsers_, which are shared from Owner,
  // the other kind is pendingChildren_, which are shared from another User.
  std::unordered_map<ForkId, c10::weak_intrusive_ptr<RRef>, ForkId::Hash>
      tempConfirmedUsers;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    bool noPending = deleteAllUsersCV_.wait_for(lock, timeoutMillis, [this]() {
      return pendingUsers_.empty() && pendingChildren_.empty();
    });
    if (!noPending) {
      LOG(ERROR)
          << "Timed out waiting for pending UserRRefs to be confirmed by owner and parent.";
    }
    // Swap out under the lock so the delete calls below run lock-free.
    tempConfirmedUsers.swap(confirmedUsers_);
  }

  // Start sending UserRRef delete messages, after all pendings are confirmed.
  // Note, there should be no new forkings in between, because it's assumed that
  // this utility is called during graceful shutdown, where no new user RPCs can
  // be initiated anymore.
  for (const auto& user : tempConfirmedUsers) {
    c10::intrusive_ptr<RRef> rref_ptr = user.second.lock();
    if (!rref_ptr) {
      // The UserRRef was already destructed; nothing to delete.
      continue;
    }
    // tryDel() below will re-acquire lock, lock must be released here.
    rref_ptr->tryDel();
  }

  // If an rref in the owners_ map has never been forked, we will never get a
  // corresponding message from the forking node(s) telling us to delete the
  // RRef. Hence we delete the RRef here. This can occur when a remote call is
  // sent to self and times out.
  {
    std::unique_lock<std::mutex> lock(mutex_);
    std::vector<RRefId> unforkedOwners;
    for (const auto& it : owners_) {
      auto rrefId = it.first;
      if (forks_.find(rrefId) == forks_.end()) {
        // Successful fork of owner was never processed.
        unforkedOwners.push_back(rrefId);
      }
    }
    for (auto& rrefId : unforkedOwners) {
      LOG(INFO) << "Removing unforked OwnerRRef with RRefId: " << rrefId;
      auto iter = owners_.find(rrefId);
      TORCH_CHECK(
          iter != owners_.end(),
          c10::str("Did not find OwnerRRef with RRefId: ", rrefId));
      owners_.erase(iter);
    }
  }
  // Wait for this node to process all delete UserRRef messages it may get for
  // the OwnerRRefs that exist on this node.
  {
    std::unique_lock<std::mutex> lock(mutex_);
    bool noOwner = deleteAllUsersCV_.wait_for(
        lock, timeoutMillis, [this]() { return owners_.empty(); });
    if (!noOwner) {
      LOG(ERROR) << "Timed out waiting for pending OwnerRRefs to be deleted.";
    }
  }
}
306 | |
307 | c10::intrusive_ptr<RRef> RRefContext::getOrCreateRRef( |
308 | const RRefForkData& rrefForkData, |
309 | const TypePtr& type) { |
310 | auto& ownerId = rrefForkData.ownerId_; |
311 | auto& rrefId = rrefForkData.rrefId_; |
312 | auto& forkId = rrefForkData.forkId_; |
313 | if (ownerId == getWorkerId()) { |
314 | return getOrCreateOwnerRRef(rrefId, type); |
315 | } else { |
316 | return createUserRRef(ownerId, rrefId, forkId, type); |
317 | } |
318 | } |
319 | |
// Returns the OwnerRRef for rrefId, creating it on first use. Also completes
// any future registered in pendingOwners_ by an earlier getOwnerRRef() call
// that raced ahead of creation.
c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
    const RRefId& rrefId,
    const TypePtr& type) {
  std::lock_guard<std::mutex> lock(mutex_);
  const auto iter = owners_.find(rrefId);
  if (iter == owners_.end()) {
    // Scenario (1) the first time this owner knows about this RRef
    //
    // NB: cannot use a make_shared-style helper here as the constructor of
    // OwnerRRef is private.
    auto rref = c10::make_intrusive<OwnerRRef>(
        getWorkerId(), rrefId, type, agent_->getDevices());
    owners_[rref->rrefId()] = rref;
    const auto pendingOwnerIter = pendingOwners_.find(rrefId);
    if (pendingOwnerIter != pendingOwners_.end()) {
      // Someone already asked for this RRef before it existed: fulfill their
      // future now. Cast to RRefInterface to hold it in an IValue.
      auto rrefPtr = fromOwnerRRef(rref);
      pendingOwnerIter->second->markCompleted(IValue(rrefPtr));
      pendingOwners_.erase(pendingOwnerIter);
    }
    return rref;
  } else {
    // Scenario (2) retrieving an existing RRef
    auto ownerRRef = fromRRefInterface(iter->second);
    // Now double check if the two types match
    //
    // Why are we special-casing the check for tensor types here?
    // Tensor types might get specialized on tensors when we pass inputs to a
    // function, i.e. a TensorType can be filled with specific shape info,
    // requires_grad info, etc. So the OwnerRRef we found might already carry
    // that info, while the `type` passed in here is a plain TensorType. They
    // are not equal; the relationship is:
    //     specialized TensorType <: plain TensorType
    //
    // In RPC we don't care about the difference as we ser/de with just the
    // plain TensorType. This is not an issue for UserRRef creation either,
    // since a Tensor can only get specialized by a previous local JIT run,
    // and we shouldn't propagate that ephemeral specialized sub-type
    // information to other workers.
    if (type->isSubtypeOf(*TensorType::get())) {
      TORCH_INTERNAL_ASSERT(
          ownerRRef->type()->isSubtypeOf(*TensorType::get()),
          "Expect OwnerRRef to be a sub-type of TensorType, but got ",
          ownerRRef->type()->repr_str());
    } else {
      TORCH_INTERNAL_ASSERT(
          *ownerRRef->type() == *type,
          "OwnerRRef type is ",
          ownerRRef->type()->repr_str(),
          ", expected type is ",
          type->repr_str());
    }
    return ownerRRef;
  }
}
375 | |
// Creates a fresh, untracked OwnerRRef with a new globally unique id.
c10::intrusive_ptr<OwnerRRef> RRefContext::createOwnerRRef(
    const TypePtr& type) {
  // Don't add this OwnerRRef to the owners_ map yet, otherwise
  // it will never be removed from there. Instead, only add it to the
  // map in prepareChildFork, in case this local RRef is being passed
  // to another worker.
  return c10::make_intrusive<OwnerRRef>(
      getWorkerId(), genGloballyUniqueId(), type, agent_->getDevices());
}
385 | |
// Returns a future that completes with the OwnerRRef for rrefId. If the
// OwnerRRef does not exist yet (used before created), the returned future is
// completed later by getOrCreateOwnerRRef(); with forceCreated set, a missing
// OwnerRRef is an internal error instead.
c10::intrusive_ptr<JitFuture> RRefContext::getOwnerRRef(
    const RRefId& rrefId,
    bool forceCreated) {
  std::unique_lock<std::mutex> lock(mutex_);
  const auto iter = owners_.find(rrefId);
  if (iter == owners_.end()) {
    if (forceCreated) {
      TORCH_INTERNAL_ASSERT(
          false,
          c10::str("Expected OwnerRRef with id ", rrefId, " to be created."));
    }
    // Scenario (1) RRef is used before it is created
    const auto pendingOwnerIter = pendingOwners_.find(rrefId);
    if (pendingOwnerIter == pendingOwners_.end()) {
      // Note: The type passed into RRefType::create() does not matter here, as
      // the future is marked as completed with the RRef of the specific type
      // in getOrCreateOwnerRRef().
      // We need to set devices here, even if they won't be used by the value
      // (an RRef object doesn't contain any tensors, it just provides means to
      // retrieve them) because we need them to be propagated to child futures.
      // This is silly and we should find a way to avoid this.
      auto futureOwner = c10::make_intrusive<JitFuture>(
          RRefType::create(c10::AnyType::get()), agent_->getDevices());
      pendingOwners_[rrefId] = futureOwner;
      return futureOwner;
    } else {
      // Reuse the future another caller already registered for this rrefId.
      return pendingOwnerIter->second;
    }
  } else {
    // Scenario (2) retrieving an existing RRef
    // Mark the IValue future as completed with the RRef IValue.
    auto owner = iter->second;
    auto rrefPtr = fromOwnerRRef(owner);

    // We need to set devices here, even if they won't be used by the value (an
    // RRef object doesn't contain any tensors, it just provides means to
    // retrieve them) because we need them to be propagated to child futures.
    // This is silly and we should find a way to avoid this.
    auto futureOwner = c10::make_intrusive<JitFuture>(
        RRefType::create(owner->type()), agent_->getDevices());
    futureOwner->markCompleted(IValue(rrefPtr));
    return futureOwner;
  }
}
430 | |
// Called while pickling an RRef that is being sent to another worker.
// Produces the fork data for the wire, and registers the necessary local
// bookkeeping so the RRef stays alive until the remote side acknowledges it.
RRefForkData RRefContext::prepareChildFork(
    const c10::intrusive_ptr<RRef>& rref) {
  // If we know that rref creation on the owner has timed out, raise it to the
  // user here, otherwise continue with pickling.
  TORCH_CHECK(
      !rref->getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  auto rrefForkData = rref->fork();
  if (rref->isOwner()) {
    // Note [Early Fork Registration]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // If the parent (caller) is the owner, directly register the fork, instead
    // of waiting for another RREF_FORK_REQUEST or RREF_CHILD_ACCEPT message.
    // An alternative is to add the fork when the callee user ACKs; however,
    // before that, the owner would still have to add the OwnerRRef into some
    // map to keep it alive (e.g., in pendingChildren_). Hence, adding the fork
    // here instead of in the ACK makes no difference except less complexity.
    // TODO: When adding failure retries and timeout, this fork needs to be
    // deleted if the owner does not receive the ACK within the timeout.
    addForkOfOwner(rrefForkData.rrefId_, rrefForkData.forkId_);
    // Ensure that this RRef is in the owners_ list to keep it alive.
    // This is needed for OwnerRRefs that were created locally.
    {
      std::lock_guard<std::mutex> lock(mutex_);
      owners_[rref->rrefId()] = rref;
    }
  } else {
    // Note [Useful Phantom Fork ID for User to Owner Call]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // If the callee of dist.remote or dist.rpc is the owner of this RRef, the
    // callee will not create a fork using this rrefForkData.forkId_, because
    // the owner will only keep one `OwnerRRef` instance and will not create any
    // `UserRRef` instances. However, this rrefForkData.forkId_ is still
    // necessary, as the caller user needs to keep this `UserRRef` alive until
    // it gets the ACK from the callee owner. Otherwise, the delete message
    // could arrive at the owner before this dist.rpc or dist.remote call, which
    // could potentially trigger the `OwnerRRef` to be deleted before running
    // the user code.
    addPendingChild(rrefForkData.forkId_, rref);
  }
  return rrefForkData;
}
475 | |
// Handles a fork of `rref` (identified by forkId) that was shared to this
// worker by worker `parent`. Depending on who shared it and who owns it, this
// either finalizes local bookkeeping or notifies the owner/parent so they can
// release their pending references.
void RRefContext::notifyOwnerAndParentOfFork(
    const ForkId& forkId,
    worker_id_t parent,
    const c10::intrusive_ptr<RRef>& rref) {
  // Fork is shared from owner.
  if (parent == rref->owner()) {
    if (parent == agent_->getWorkerInfo().id_) {
      // Owner sending RRef to self, remove the forkId as it was added during
      // pickling
      auto deletedRRef = delForkOfOwner(rref->rrefId(), forkId);
      if (deletedRRef) {
        TORCH_INTERNAL_ASSERT(
            deletedRRef->rrefId() == rref->rrefId(),
            "Deleting a fork of ",
            rref->rrefId(),
            " triggered deleting the OwnerRRef of ",
            deletedRRef->rrefId());
        // NB: not necessary to reset deletedRRef as rref is another reference
        // to the same OwnerRRef instance.
      }
    } else {
      // If the parent is the owner, this fork has already been added into the
      // forks_ map when the owner sends the message to the callee user.
      // Hence, it is not necessary to send another RREF_CHILD_ACCEPT or
      // RREF_FORK_REQUEST back to the owner. See Note [Early Fork
      // Registration].
      std::lock_guard<std::mutex> lock(mutex_);
      addConfirmedUser(forkId, rref);
    }
    return;
  }

  // Fork is shared from user.
  if (rref->isOwner()) {
    // See Note [Useful Phantom Fork ID for User to Owner Call]
    // In this case, the owner is the caller, and it does not add the fork id
    // into forks_. Because, there will be no real `UserRRef` associated
    // with this fork ID.
    ++numPendingFutures_;
    auto jitFuture = agent_->sendWithRetries(
        agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
    jitFuture->addCallback([this](JitFuture& future) {
      handleExceptionSilent(future);
      --numPendingFutures_;
    });
  } else {
    // A user received a fork from another user: ask the owner to register it.
    ++numPendingFutures_;
    auto jitFuture = agent_->sendWithRetries(
        agent_->getWorkerInfo(rref->owner()),
        RRefForkRequest(rref->rrefId(), forkId).toMessage());

    addPendingUser(forkId, rref);

    jitFuture->addCallback([this, forkId, parent](JitFuture& future) {
      handleException(future);
      this->finishForkRequest(forkId, parent);
      // Decrease after calling finishForkRequest because, as that creates a new
      // future, it might otherwise cause the count to briefly go to zero.
      --numPendingFutures_;
    });
  }
}
538 | |
// Keeps a child UserRRef alive until the callee acknowledges receiving the
// fork (the matching removal is delPendingChild()).
void RRefContext::addPendingChild(
    const ForkId& forkId,
    const c10::intrusive_ptr<RRef>& rref) {
  // see Note [Early Fork Registration]
  // If the parent is the owner, it should directly add the child UserRRef as a
  // fork.
  TORCH_INTERNAL_ASSERT(
      !rref->isOwner(), "OwnerRRef should not have a pending child.");
  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_INTERNAL_ASSERT(
      pendingChildren_.find(forkId) == pendingChildren_.end(),
      "Inconsistent states: attempt to add the same child fork twice.");
  pendingChildren_[forkId] = rref;
}
553 | |
// Removes a child UserRRef from the pending map once the callee has
// acknowledged it. Idempotent: duplicate requests (from message retries) are
// logged and ignored.
void RRefContext::delPendingChild(const ForkId& forkId) {
  c10::intrusive_ptr<RRef> deletedUser;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = pendingChildren_.find(forkId);
    // We first check whether the child exists in pendingChildren_. It's
    // possible the child may have been removed by a previous send attempt, and
    // this check (as opposed to an assertion here) ensures that messages that
    // trigger this function are idempotent.
    if (iter != pendingChildren_.end()) {
      // Since this UserRRef is removed from the map,
      // the refcount of this UserRRef could reach to 0,
      // so the "destructor", `release_resources()`, might be called,
      // in which the lock is acquired again.
      // So it must be destructed with the lock released.
      // Meet this constraint by creating a temporary pointer to increase the
      // refcount, extending its lifetime until lock released.
      deletedUser = iter->second; // Increase refcount.
      pendingChildren_.erase(iter); // Decrease refcount.
    } else {
      LOG(INFO) << "Ignoring duplicate request to delete child UserRRef with "
                << "ForkId = " << forkId;
    }
  }
  // Wake up delAllUsersAndUnforkedOwners(), which waits on this CV.
  deleteAllUsersCV_.notify_all();
  // The refcount of this UserRRef could reach to 0,
  // so the "destructor", release_resources(), might be called,
  // in which the lock is acquired again,
  // so must destruct it with the lock released.
  deletedUser.reset(); // Decrease refcount.
}
585 | |
// Registers a UserRRef that has been created but not yet confirmed by its
// owner. The matching delPendingUser() runs in the RPC response callback.
void RRefContext::addPendingUser(
    const ForkId& forkId,
    const c10::intrusive_ptr<RRef>& rref) {
  TORCH_INTERNAL_ASSERT(
      !rref->isOwner(), "Attempt to add an OwnerRRef as a pending User.");

  auto state = std::make_shared<PendingUserState>(rref);
  if (recording_) {
    // Adding and waiting for pending users are guaranteed to be called from
    // the same thread, but deleting pending users will be called from another
    // thread. As delPendingUser will not be able to access the same
    // thread_local variable, we cannot address this problem by making
    // pendingUsers_ thread_local. Instead, pendingUsers_ and userTable_ share
    // the same PendingUserState shared_ptr.
    userTable_.push_back(state);
  }

  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_INTERNAL_ASSERT(
      pendingUsers_.find(forkId) == pendingUsers_.end(),
      "Inconsistent states: attempt to add the same UserRRef twice.");

  pendingUsers_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(forkId),
      std::forward_as_tuple(state));
}
613 | |
// Confirms a pending UserRRef (the owner acknowledged the fork): moves it to
// confirmedUsers_ and runs its confirmation callbacks outside the lock.
void RRefContext::delPendingUser(const ForkId& forkId) {
  std::shared_ptr<PendingUserState> deletedState = nullptr;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = pendingUsers_.find(forkId);
    TORCH_INTERNAL_ASSERT(
        iter != pendingUsers_.end(),
        "Inconsistent states: attempt to delete a non-exist UserRRef.");

    // There are two reasons for keeping the deleted PendingUserState alive
    // until exiting the critical section.
    // (1) Since this UserRRef is removed from the map, the refcount of this
    //     UserRRef could reach to 0. So the resource destructor
    //     (`release_resources()`) might be called, in which the lock is
    //     acquired again. Hence, it must be destructed with the lock released.
    //     To meet this constraint, we intentionally create a temporary pointer
    //     to increase the refcount of the deleted PendingUserState, extending
    //     its lifetime until lock released.
    // (2) Since #34497, a user function only runs after all RRefs in the
    //     arguments are confirmed by their owners, which is done by adding the
    //     RPC processing logic as a callback to the UserRRef ready future. So,
    //     calling `confirm` on the PendingUserState could trigger pending user
    //     functions, which might in turn acquire the lock in RRefContext.
    //     Hence, we must release the lock to prevent deadlock.
    // NB: Another option is to use reentrant lock. However, it is better for
    // the developers to fully understand the locking behavior instead of
    // hiding the subtle logic using a reentrant lock.
    deletedState = iter->second; // Increase refcount

    addConfirmedUser(forkId, iter->second->rref_);
    pendingUsers_.erase(iter); // Decrease refcount.
  }
  // Run confirmation callbacks with mutex_ released (see (2) above).
  deletedState->confirm();
  deleteAllUsersCV_.notify_all();
  deletedState.reset(); // Decrease refcount.
}
650 | |
// Records a user whose fork the owner has acknowledged. Stored as a weak
// reference so the UserRRef can still be destructed normally; delUser()
// erases the entry.
void RRefContext::addConfirmedUser(
    const ForkId& forkId,
    const c10::intrusive_ptr<RRef>& rref) {
  // Notice, the caller needs to hold the mutex for confirmedUsers_.
  // std::lock_guard<std::mutex> lock(mutex_);
  confirmedUsers_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(forkId),
      std::forward_as_tuple(rref));
}
661 | |
662 | c10::intrusive_ptr<RRef> RRefContext::getPendingUser(const ForkId& forkId) { |
663 | std::lock_guard<std::mutex> lock(mutex_); |
664 | auto it = pendingUsers_.find(forkId); |
665 | if (it == pendingUsers_.end()) { |
666 | TORCH_INTERNAL_ASSERT( |
667 | false, "Pending user with forkId " , forkId, " not found" ); |
668 | } |
669 | return it->second->rref_; |
670 | } |
671 | |
// Starts recording UserRRefs created on this thread: until
// waitForThreadLocalPendingRRefs() is called, addPendingUser() also appends
// each new state to userTable_.
void RRefContext::recordThreadLocalPendingRRefs() {
  TORCH_INTERNAL_ASSERT(
      userTable_.empty(),
      "User RRef Table should be empty when start recording");
  recording_ = true;
}
678 | |
// Stops recording and returns a future that completes once every UserRRef
// recorded on this thread has been confirmed by its owner.
c10::intrusive_ptr<JitFuture> RRefContext::waitForThreadLocalPendingRRefs() {
  // We need to set devices here, even if they won't be used by the value (it's
  // a bool, it doesn't contain tensors!) because we need them to be propagated
  // to child futures. This is silly and we should find a way to avoid this.
  auto jitFuturePtr =
      c10::make_intrusive<JitFuture>(BoolType::get(), agent_->getDevices());
  if (userTable_.empty()) {
    // Nothing was recorded; complete immediately.
    jitFuturePtr->markCompleted(true);
  } else {
    // Complete the aggregate future when the last confirmation arrives; the
    // shared atomic counter makes the countdown race-free across callbacks.
    auto remainingRRefs =
        std::make_shared<std::atomic<uint64_t>>(userTable_.size());
    for (auto& state : userTable_) {
      state->confirmationFuture_->addCallback(
          [jitFuturePtr, remainingRRefs](JitFuture& /* unused */) {
            auto localCount = remainingRRefs->fetch_sub(1);
            if (localCount == 1) {
              jitFuturePtr->markCompleted(true);
            }
          });
    }
    userTable_.clear();
  }
  recording_ = false;
  return jitFuturePtr;
}
704 | |
// Aborts an in-progress recording (see recordThreadLocalPendingRRefs) after
// a failure, discarding any pending states collected so far on this thread.
void RRefContext::clearRecordedPendingRRefsOnError() {
  userTable_.clear();
  recording_ = false;
}
709 | |
// Runs after the owner acknowledged an RREF_FORK_REQUEST: confirms the local
// pending user, then sends RREF_CHILD_ACCEPT so the parent can release its
// pending-child reference.
void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
  delPendingUser(forkId);
  ++numPendingFutures_;
  auto jitFuture = agent_->sendWithRetries(
      agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());

  // Errors on the accept message are logged, not raised (best effort).
  jitFuture->addCallback([this](JitFuture& future) {
    handleExceptionSilent(future);
    --numPendingFutures_;
  });
}
721 | |
// Registers the owner itself as a "fork" of its own RRef, keyed by the rrefId
// (owners have no forkId). This keeps the OwnerRRef alive in owners_ until
// that self-fork is deleted.
void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) {
  std::lock_guard<std::mutex> lock(mutex_);
  const auto& rrefId = rref->rrefId();
  owners_[rrefId] = rref;
  auto& rrefForks = forks_[rrefId];
  TORCH_INTERNAL_ASSERT(
      rrefForks.find(rrefId) == rrefForks.end(),
      "Attempt to add self as fork twice ",
      rrefId);
  rrefForks.insert(rrefId);
}
733 | |
// Registers a new fork of an owned RRef. Unlike addForkOfOwnerIfNotPresent(),
// a duplicate notification is treated as an internal error here.
void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
  std::lock_guard<std::mutex> lock(mutex_);
  auto& rrefForks = forks_[rrefId];
  TORCH_INTERNAL_ASSERT(
      rrefForks.find(forkId) == rrefForks.end(),
      "Got fork notification twice on the same RRef ",
      forkId);
  rrefForks.insert(forkId);
}
743 | |
744 | void RRefContext::addForkOfOwnerIfNotPresent( |
745 | const RRefId& rrefId, |
746 | const ForkId& forkId) { |
747 | std::lock_guard<std::mutex> lock(mutex_); |
748 | auto& rrefForks = forks_[rrefId]; |
749 | // We first check whether the child exists in rrefForks. It's possible |
750 | // the child may have been added by a previous send attempt, and this check |
751 | // (as opposed to an assertion here) ensures that messages that trigger this |
752 | // function are idempotent. |
753 | if (rrefForks.find(forkId) == rrefForks.end()) { |
754 | rrefForks.insert(forkId); |
755 | } else { |
756 | LOG(INFO) << "Ignoring duplicate request to add Fork of OwnerRRef with " |
757 | << "RRefId = " << rrefId << ", ForkId = " << forkId; |
758 | } |
759 | } |
760 | |
// Removes a fork of an owned RRef. When the last fork of an RRef is removed,
// the OwnerRRef itself is erased from owners_ and returned (otherwise returns
// nullptr). Idempotent with respect to retried RRefUserDelete messages.
c10::intrusive_ptr<RRef> RRefContext::delForkOfOwner(
    const RRefId& rrefId,
    const ForkId& forkId) {
  c10::intrusive_ptr<RRef> deletedRRef;
  bool ownerReduced = false;
  // There were previously multiple TORCH_CHECKs in this function that checked
  // whether the passed in fork was known by the user and whether the fork had
  // already been deleted. These assertions are now replaced with nested if
  // statements to ensure this function is idempotent. This makes it safe to
  // retry RRefUserDelete messages.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto rrefIter = forks_.find(rrefId);
    if (rrefIter != forks_.end()) {
      auto& rrefForks = rrefIter->second;
      auto forkIter = rrefForks.find(forkId);
      if (forkIter != rrefForks.end()) {
        rrefForks.erase(forkId);
      } else {
        LOG(INFO)
            << "Could not find UserRRef instance, "
            << "RRefId = " << rrefId << ", ForkId = " << forkId
            << ", likely because it was deleted by a previously retried message";
      }
      if (rrefForks.empty()) {
        // Last fork gone: drop the OwnerRRef. Keep a reference to it so it is
        // destructed after the lock is released (via the return value).
        auto ownerIter = owners_.find(rrefId);
        if (ownerIter != owners_.end()) {
          deletedRRef = ownerIter->second;
          owners_.erase(ownerIter);
          ownerReduced = true;
        }
        forks_.erase(rrefIter);
      }
    } else {
      LOG(INFO)
          << "Could not find OwnerRRef with RRefId = " << rrefId
          << ", likely because it was deleted by a previously retried message";
    }
  }
  if (ownerReduced) {
    // Wake up delAllUsersAndUnforkedOwners(), which waits for owners_ to
    // become empty.
    deleteAllUsersCV_.notify_all();
  }
  return deletedRRef;
}
805 | |
806 | } // namespace rpc |
807 | } // namespace distributed |
808 | } // namespace torch |
809 | |