1#include <torch/csrc/distributed/rpc/rref_context.h>
2#include <torch/csrc/distributed/rpc/rref_proto.h>
3#include <torch/csrc/distributed/rpc/utils.h>
4
5#include <sstream>
6
7namespace torch {
8namespace distributed {
9namespace rpc {
10
// Thread-local table of UserRRefs created on the current thread that are
// still awaiting owner confirmation. Populated by addPendingUser() while
// recording_ is true; consumed by waitForThreadLocalPendingRRefs().
thread_local std::vector<std::shared_ptr<RRefContext::PendingUserState>>
    RRefContext::userTable_;
// When true, addPendingUser() also appends the new PendingUserState to
// userTable_ (see recordThreadLocalPendingRRefs()).
thread_local bool RRefContext::recording_ = false;
14
15namespace callback {
// Callback invoked when the owner responds to the RPC that created a pending
// UserRRef. On success, verifies that the owner acknowledged the expected
// fork; on error, records the error on the RRef (best effort, see note
// below). In both cases the pending user entry is removed at the end.
void confirmPendingUser(
    const JitFuture& jitFuture,
    const ForkId& expectedForkId) {
  if (!jitFuture.hasError()) {
    auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
    auto msgType = msgPtr->type();
    auto rpc = deserializeResponse(*msgPtr, msgType);
    auto& rr = dynamic_cast<RemoteRet&>(*rpc);
    TORCH_INTERNAL_ASSERT(rr.forkId() == expectedForkId);
  } else {
    // Handle errors, such as timeouts, by invoking the error handler on the
    // rref.
    // Note [Best Effort Error handling for Remote calls]:
    // When remote calls initiated by rpc.remote() fail, such as with a timeout
    // error, we take a best-effort approach to error handling. We handle errors
    // when callbacks corresponding to the remote call run, and set the error
    // information on the RRef. If the RRef has not been used by the application
    // before this process (such as to_here or fork call), then future uses of
    // the RRef will appropriately raise errors. However, it is possible that
    // the user application will use the RRef before the errors are handled. In
    // this case, errors may not be raised as they have not yet been handled.
    auto rref_ptr = RRefContext::getInstance().getPendingUser(expectedForkId);
    auto errorType = getRPCErrorType(jitFuture);
    rref_ptr->handleError(errorType, jitFuture);
  }
  RRefContext::getInstance().delPendingUser(expectedForkId);
}
43
// Callback invoked when a remote call that created an OwnerRRef (send-to-self
// case) completes. On error, records the error on the already-created
// OwnerRRef; on success, sanity-checks the response. Either way the self-fork
// is deleted, and the OwnerRRef is returned if that deletion removed it from
// the owners_ map (nullptr otherwise, see delForkOfOwner()).
c10::intrusive_ptr<RRef> finishCreatingOwnerRRef(
    const JitFuture& jitFuture,
    const RRefId& rrefId) {
  if (jitFuture.hasError()) {
    auto& ctx = RRefContext::getInstance();
    // We expect to run this callback only after the OwnerRRef has been created,
    // since this is only invoked when sending to self.
    auto rref_ptr =
        fromRRefInterface(ctx.getOwnerRRef(rrefId, /* forceCreated */ true)
                              ->constValue()
                              .toRRef());
    auto errorType = getRPCErrorType(jitFuture);
    rref_ptr->handleError(errorType, jitFuture);
    // OwnerRRefs do not have a forkId, so don't need to assert here.
    auto deletedRRef =
        ctx.delForkOfOwner(rref_ptr->rrefId(), rref_ptr->rrefId());
    return deletedRRef;
  } else {
    auto msgPtr = jitFuture.constValue().toCustomClass<Message>();
    auto msgType = msgPtr->type();
    auto rpc = deserializeResponse(*msgPtr, msgType);
    auto& rr = dynamic_cast<RemoteRet&>(*rpc);
    TORCH_INTERNAL_ASSERT(
        rr.rrefId() == rr.forkId(),
        "Expecting an OwnerRRef as RemoteRet but got a fork.");
    auto& ctx = RRefContext::getInstance();
    // For OwnerRRefs the forkId equals the rrefId (the self-fork added when
    // the RRef was created), so delete that self-fork here.
    auto deletedRRef = ctx.delForkOfOwner(rr.rrefId(), rr.rrefId());
    return deletedRRef;
  }
}
74
75} // namespace callback
76
// Keys for RRef-related debug information reported by
// RRefContext::getDebugInfo().
const std::string kNumOwnerRRefs = "num_owner_rrefs";
const std::string kNumPendingFutures = "num_pending_futures";
const std::string kNumPendingUsers = "num_pending_users";
const std::string kNumForks = "num_forks";
82
83RRefContext& RRefContext::getInstance() {
84 // Leaky singleton to avoid module destructor races.
85 static RRefContext* context = new RRefContext(RpcAgent::getCurrentRpcAgent());
86 return *context;
87}
88
// Tears down the singleton's state during RPC shutdown. Marks the context as
// destroyed (so delUser() stops sending delete messages), checks for leaked
// forks, and clears the owner maps. Returns the OwnerRRefs that wrap Python
// objects — presumably so the caller can release them where the Python GIL
// can be held; TODO confirm against the caller.
// NOTE(review): owners_ and pendingOwners_ are accessed here without holding
// mutex_ — presumably safe because this runs at shutdown with no concurrent
// mutation; confirm before relying on this pattern elsewhere.
std::vector<c10::intrusive_ptr<RRef>> RRefContext::destroyInstance(
    bool ignoreRRefLeak) {
  auto& ctx = RRefContext::getInstance();
  {
    std::lock_guard<std::mutex> lock(ctx.destroyedMutex_);
    ctx.destroyed_ = true;
  }
  ctx.checkRRefLeaks(ignoreRRefLeak);
  std::vector<c10::intrusive_ptr<RRef>> deletedRRefs;
  for (auto& entry : ctx.owners_) {
    auto rref = entry.second;
    if (rref->isPyObj()) {
      deletedRRefs.emplace_back(std::move(rref));
    }
  }
  ctx.owners_.clear();
  ctx.pendingOwners_.clear();
  return deletedRRefs;
}
108
109void RRefContext::handleException(const JitFuture& jitFuture) {
110 if (jitFuture.hasError()) {
111 auto errMsg = jitFuture.tryRetrieveErrorMessage();
112 VLOG(1) << "Got exception: " << errMsg;
113 TORCH_CHECK(false, errMsg);
114 }
115}
116
117void RRefContext::handleExceptionSilent(const JitFuture& jitFuture) {
118 if (jitFuture.hasError()) {
119 auto errMsg = jitFuture.tryRetrieveErrorMessage();
120 VLOG(1) << "Got exception: " << errMsg;
121 TORCH_CHECK_MSG(false, errMsg);
122 }
123}
124
// Takes shared ownership of the RpcAgent used for all fork/delete messaging.
RRefContext::RRefContext(std::shared_ptr<RpcAgent> agent)
    : agent_(std::move(agent)) {}
127
128RRefContext::~RRefContext() {
129 if (!owners_.empty()) {
130 VLOG(1) << "Destructing RRefContext with non-empty OwnerRRef set. "
131 << "This would likely cause Python deref error. "
132 << "Make sure destroyInstance() is invoked before destruction.";
133 }
134}
135
136std::unordered_map<std::string, std::string> RRefContext::getDebugInfo() {
137 std::unordered_map<std::string, std::string> info;
138 std::unique_lock<std::mutex> lock(mutex_);
139 auto ownerSize = owners_.size();
140 auto numPendingUsers = pendingUsers_.size();
141 int numForks = 0;
142 for (const auto& owner : forks_) {
143 numForks += owner.second.size();
144 }
145 lock.unlock();
146 info[kNumOwnerRRefs] = c10::to_string(ownerSize);
147 info[kNumPendingFutures] = c10::to_string(numPendingFutures_.load());
148 info[kNumPendingUsers] = c10::to_string(numPendingUsers);
149 info[kNumForks] = c10::to_string(numForks);
150 return info;
151}
152
153void RRefContext::checkRRefLeaks(bool ignoreRRefLeak) {
154 if (!forks_.empty()) {
155 std::stringstream ss;
156 for (auto& entry : forks_) {
157 const RRefId& rrefId = entry.first;
158 for (const auto& forkId : entry.second) {
159 ss << "Leaking RRef " << rrefId << " with fork Id " << forkId
160 << std::endl;
161 }
162 }
163
164 LOG(WARNING)
165 << "Detected RRef Leaks during shutdown. This usually "
166 << "occurs when the application code still holds references to RRef "
167 << "instances when calling shutdown(). If the program has "
168 << "completed correctly and the process is exiting, it is OK to "
169 << "ignore these leaks. However, if you program will keep running "
170 << "after this, these leaks could result in memory leaks on RRef "
171 << "owners. Please make sure all RRefs are out of scope and Python "
172 << "GC has deleted them before calling shutdown(): \n"
173 << ss.str();
174 if (!ignoreRRefLeak) {
175 TORCH_CHECK(false, ss.str());
176 }
177 }
178}
179
180c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(
181 worker_id_t ownerId,
182 const TypePtr& type) {
183 TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner.");
184 // Explicitly creating rrefId before forkId to make sure the order is
185 // deterministic, as the argument evaluation order is system and compiler
186 // dependent.
187 const auto rrefId = genGloballyUniqueId();
188 const auto forkId = genGloballyUniqueId();
189 return createUserRRef(ownerId, rrefId, forkId, type);
190}
191
192c10::intrusive_ptr<UserRRef> RRefContext::createUserRRef(
193 worker_id_t ownerId,
194 const RRefId& rrefId,
195 const ForkId& forkId,
196 const TypePtr& type) {
197 TORCH_CHECK(ownerId != getWorkerId(), "RRef owner cannot create user RRef.");
198 // RRefContext does not track user RRefs, it will be destructed when there
199 // is no shared_ptrs pointing to it.
200 //
201 // NB: cannot use make_shared here as the constructor of UserRRef is private.
202 // NB: This UserRRef has not been confirmed by the owner yet. This function's
203 // call site is responsible for adding this UserRRef to pendingUsers_.
204 // Currently, there are two call sites.
205 // (1) The creator user in python_functions.cpp
206 // (2) The callee user in RRefContext::notifyOwnerAndParentOfFork.
207 //
208 // The reason for not adding the pending user here is to put addPendingUser()
209 // close to where the RPC occurs, and it is more clear to pair it with
210 // deletePendingUser() in the response callback at the call site.
211 return c10::make_intrusive<UserRRef>(ownerId, rrefId, forkId, type);
212}
213
// Called when a UserRRef on this worker is being deleted: notifies the owner
// (via RRefUserDelete, with retries) so it can drop the corresponding fork,
// then removes the fork from confirmedUsers_. The send is skipped once the
// context has been destroyed (shutdown already handled cleanup).
// NOTE: the RPC is issued while holding destroyedMutex_, so the destroyed_
// check and the send are atomic with respect to destroyInstance().
void RRefContext::delUser(
    const worker_id_t owner,
    const RRefId& rrefId,
    const ForkId& forkId) {
  {
    std::lock_guard<std::mutex> lock(destroyedMutex_);
    if (!destroyed_) {
      // Sending an RRefUserDelete causes the receiver to run delForkOfOwner,
      // which is now idempotent. See the comment at RRefContext::delForkOfOwner
      // for more details.
      ++numPendingFutures_;
      auto jitFuture = agent_->sendWithRetries(
          agent_->getWorkerInfo(owner),
          RRefUserDelete(rrefId, forkId).toMessage());

      jitFuture->addCallback([this](JitFuture& future) {
        handleExceptionSilent(future);
        --numPendingFutures_;
      });
    }
  }

  std::lock_guard<std::mutex> lock(mutex_);
  confirmedUsers_.erase(forkId);
}
239
// Graceful-shutdown helper. In order: (1) waits up to timeoutMillis for all
// pending UserRRefs to be confirmed, (2) sends delete messages for every
// confirmed user, (3) drops OwnerRRefs that were never forked, and (4) waits
// up to timeoutMillis for the owners_ map to drain as delete messages from
// other workers are processed.
void RRefContext::delAllUsersAndUnforkedOwners(
    std::chrono::milliseconds timeoutMillis) {
  // First, wait for all pending UserRRefs to be confirmed,
  // one kind is pendingUsers_, which are shared from Owner,
  // the other kind pendingChildren_, which are shared from another User.
  std::unordered_map<ForkId, c10::weak_intrusive_ptr<RRef>, ForkId::Hash>
      tempConfirmedUsers;
  {
    std::unique_lock<std::mutex> lock(mutex_);
    bool noPending = deleteAllUsersCV_.wait_for(lock, timeoutMillis, [this]() {
      return pendingUsers_.empty() && pendingChildren_.empty();
    });
    if (!noPending) {
      LOG(ERROR)
          << "Timed out waiting for pending UserRRefs to be confirmed by owner and parent.";
    }
    // Swap out confirmedUsers_ so the delete RPCs below run without mutex_.
    tempConfirmedUsers.swap(confirmedUsers_);
  }

  // Start sending UserRRef delete messages, after all pendings are confirmed.
  // Note, there should be no new forkings in between, because it's assumed that
  // this utility is called during graceful shutdown, where no new user RPCs can
  // be initiated anymore.
  for (const auto& user : tempConfirmedUsers) {
    c10::intrusive_ptr<RRef> rref_ptr = user.second.lock();
    if (!rref_ptr) {
      continue;
    }
    // tryDel() below will re-acquire lock, lock must be released here.
    rref_ptr->tryDel();
  }

  // If an rref in the owners_ map has never been forked, we will never get a
  // corresponding message from the forking node(s) telling us to delete the
  // RRef. Hence we delete the RRef here. This can occur when a remote call is
  // sent to self and times out.
  {
    std::unique_lock<std::mutex> lock(mutex_);
    std::vector<RRefId> unforkedOwners;
    for (const auto& it : owners_) {
      auto rrefId = it.first;
      if (forks_.find(rrefId) == forks_.end()) {
        // Successful fork of owner was never processed.
        unforkedOwners.push_back(rrefId);
      }
    }
    for (auto& rrefId : unforkedOwners) {
      LOG(INFO) << "Removing unforked OwnerRRef with RRefId: " << rrefId;
      auto iter = owners_.find(rrefId);
      TORCH_CHECK(
          iter != owners_.end(),
          c10::str("Did not find OwnerRRef with RRefId: ", rrefId));
      owners_.erase(iter);
    }
  }
  // Wait for this node to process all delete UserRRef messages it may get for
  // the OwnerRRefs that exist on this node.
  {
    std::unique_lock<std::mutex> lock(mutex_);
    bool noOwner = deleteAllUsersCV_.wait_for(
        lock, timeoutMillis, [this]() { return owners_.empty(); });
    if (!noOwner) {
      LOG(ERROR) << "Timed out waiting for pending OwnerRRefs to be deleted.";
    }
  }
}
306
307c10::intrusive_ptr<RRef> RRefContext::getOrCreateRRef(
308 const RRefForkData& rrefForkData,
309 const TypePtr& type) {
310 auto& ownerId = rrefForkData.ownerId_;
311 auto& rrefId = rrefForkData.rrefId_;
312 auto& forkId = rrefForkData.forkId_;
313 if (ownerId == getWorkerId()) {
314 return getOrCreateOwnerRRef(rrefId, type);
315 } else {
316 return createUserRRef(ownerId, rrefId, forkId, type);
317 }
318}
319
// Returns the OwnerRRef for rrefId, creating it if this is the first time
// this owner hears about the RRef. Creation also completes any future parked
// in pendingOwners_ by a prior getOwnerRRef() call. On retrieval, the stored
// type is checked against `type` (with a relaxed check for tensors, see
// below).
c10::intrusive_ptr<OwnerRRef> RRefContext::getOrCreateOwnerRRef(
    const RRefId& rrefId,
    const TypePtr& type) {
  std::lock_guard<std::mutex> lock(mutex_);
  const auto iter = owners_.find(rrefId);
  if (iter == owners_.end()) {
    // Scenario (1) the first time this owner knows about this RRef
    //
    // NB: cannot use make_shared here as the constructor of OwnerRRef is
    // private.
    auto rref = c10::make_intrusive<OwnerRRef>(
        getWorkerId(), rrefId, type, agent_->getDevices());
    owners_[rref->rrefId()] = rref;
    const auto pendingOwnerIter = pendingOwners_.find(rrefId);
    if (pendingOwnerIter != pendingOwners_.end()) {
      // cast to RRefInterface to hold it into IValue
      auto rrefPtr = fromOwnerRRef(rref);
      pendingOwnerIter->second->markCompleted(IValue(rrefPtr));
      pendingOwners_.erase(pendingOwnerIter);
    }
    return rref;
  } else {
    // Scenario (2) retrieving an existing RRef
    auto ownerRRef = fromRRefInterface(iter->second);
    // Now double check if the two types match
    //
    // Why are we special-casing the check for tensor types here?
    // Tensor types might get specialized on tensors when we pass inputs to
    // the function, i.e. a TensorType can be filled with specific shape
    // info, requires_grad info, etc. So the OwnerRRef we found might already
    // carry those infos, while the `type` passed in here is a plain
    // TensorType. The two are not equal, only related by subtyping:
    //   specialized TensorType <: plain TensorType
    //
    // In RPC we don't care about the difference, as we ser/de with just the
    // plain TensorType. This is not an issue for UserRRef creation either,
    // since a Tensor can only get specialized by a previous run of a local
    // JIT function, and the specialized subtype information should not be
    // propagated to other workers.
    if (type->isSubtypeOf(*TensorType::get())) {
      TORCH_INTERNAL_ASSERT(
          ownerRRef->type()->isSubtypeOf(*TensorType::get()),
          "Expect OwnerRRef to be a sub-type of TensorType, but got ",
          ownerRRef->type()->repr_str());
    } else {
      TORCH_INTERNAL_ASSERT(
          *ownerRRef->type() == *type,
          "OwnerRRef type is ",
          ownerRRef->type()->repr_str(),
          ", expected type is ",
          type->repr_str());
    }
    return ownerRRef;
  }
}
375
// Creates a fresh local OwnerRRef with a newly generated globally-unique id.
c10::intrusive_ptr<OwnerRRef> RRefContext::createOwnerRRef(
    const TypePtr& type) {
  // Don't add this OwnerRRef to the owners_ map yet, otherwise
  // it will never be removed from there. Instead, only add it to the
  // map in prepareChildFork, in case this local RRef is being passed
  // to another worker.
  return c10::make_intrusive<OwnerRRef>(
      getWorkerId(), genGloballyUniqueId(), type, agent_->getDevices());
}
385
// Returns a future that completes with the OwnerRRef (as an IValue) for
// rrefId. If the owner already exists, the future is returned completed.
// Otherwise a future is parked in pendingOwners_ and later completed by
// getOrCreateOwnerRRef(); with forceCreated == true a missing owner is
// instead treated as an internal error.
c10::intrusive_ptr<JitFuture> RRefContext::getOwnerRRef(
    const RRefId& rrefId,
    bool forceCreated) {
  std::unique_lock<std::mutex> lock(mutex_);
  const auto iter = owners_.find(rrefId);
  if (iter == owners_.end()) {
    if (forceCreated) {
      TORCH_INTERNAL_ASSERT(
          false,
          c10::str("Expected OwnerRRef with id ", rrefId, " to be created."));
    }
    // Scenario (1) RRef is used before it is created
    const auto pendingOwnerIter = pendingOwners_.find(rrefId);
    if (pendingOwnerIter == pendingOwners_.end()) {
      // Note: The type passed into RRefType::create() does not matter here, as
      // the future is marked as completed with the RRef of the specific type
      // in getOrCreateOwnerRRef().
      // We need to set devices here, even if they won't be used by the value
      // (an RRef object doesn't contain any tensors, it just provides means to
      // retrieve them) because we need them to be propagated/ to child futures.
      // This is silly and we should find a way to avoid this.
      auto futureOwner = c10::make_intrusive<JitFuture>(
          RRefType::create(c10::AnyType::get()), agent_->getDevices());
      pendingOwners_[rrefId] = futureOwner;
      return futureOwner;
    } else {
      // Another caller already parked a future for this id; share it.
      return pendingOwnerIter->second;
    }
  } else {
    // Scenario (2) retrieving an existing RRef
    // Marks IValue Future as completed with the RRef IValue.
    auto owner = iter->second;
    auto rrefPtr = fromOwnerRRef(owner);

    // We need to set devices here, even if they won't be used by the value (an
    // RRef object doesn't contain any tensors, it just provides means to
    // retrieve them) because we need them to be propagated/ to child futures.
    // This is silly and we should find a way to avoid this.
    auto futureOwner = c10::make_intrusive<JitFuture>(
        RRefType::create(owner->type()), agent_->getDevices());
    futureOwner->markCompleted(IValue(rrefPtr));
    return futureOwner;
  }
}
430
// Called while pickling an RRef for transmission to another worker. Produces
// the fork data to serialize and registers bookkeeping so the RRef stays
// alive until the child acknowledges: the fork itself for owners, a pending
// child entry for users.
RRefForkData RRefContext::prepareChildFork(
    const c10::intrusive_ptr<RRef>& rref) {
  // If we know that rref creation on the owner has timed out, raise it to the
  // user here, otherwise continue with pickling.

  TORCH_CHECK(
      !rref->getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  auto rrefForkData = rref->fork();
  if (rref->isOwner()) {
    // Note [Early Fork Registration]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // If the parent (caller) is the owner, directly register the fork, instead
    // of waiting for another RREF_FORK_REQUEST or RREF_CHILD_ACCEPT message. An
    // alternative is adding the fork when the callee user ACKs. However, before
    // that, the owner still has to add the OwnerRRef into some map to keep it
    // alive (e.g., in pendingChildren_). Hence, adding the fork here or in the
    // ACK does not make any difference, and only adds complexity.
    // TODO: When adding failure retries and timeout, this fork needs to be
    // deleted if the owner does not receive the ACK within the timeout.
    addForkOfOwner(rrefForkData.rrefId_, rrefForkData.forkId_);
    // ensure that this RRef is in the owners_ list to keep it alive.
    // this is needed for OwnerRRefs that were created locally.
    {
      std::lock_guard<std::mutex> lock(mutex_);
      owners_[rref->rrefId()] = rref;
    }
  } else {
    // Note [Useful Phantom Fork ID for User to Owner Call]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // If the callee of dist.remote or dist.rpc is the owner of this RRef, the
    // callee will not create a fork using this rrefForkData.forkId_, because
    // the owner will only keep one `OwnerRRef` instance and will not create any
    // `UserRRef` instances. However, this rrefForkData.forkId_ is still
    // necessary, as the caller user needs to keep this `UserRRef` alive until
    // it gets the ACK from the callee owner. Otherwise, the delete message
    // could arrive at the owner before this dist.rpc or dist.remote call, which
    // could potentially trigger the `OwnerRRef` to be deleted before running
    // the user code.
    addPendingChild(rrefForkData.forkId_, rref);
  }
  return rrefForkData;
}
475
// Called on the receiving worker after unpickling a forked RRef. Depending
// on who the parent (sender) is relative to the owner, this either cleans up
// the phantom self-fork, confirms the user directly, ACKs the parent, or
// sends a fork request to the owner and registers a pending user.
void RRefContext::notifyOwnerAndParentOfFork(
    const ForkId& forkId,
    worker_id_t parent,
    const c10::intrusive_ptr<RRef>& rref) {
  // Fork is shared from owner.
  if (parent == rref->owner()) {
    if (parent == agent_->getWorkerInfo().id_) {
      // Owner sending RRef to self, remove the forkId as it was added during
      // pickling
      auto deletedRRef = delForkOfOwner(rref->rrefId(), forkId);
      if (deletedRRef) {
        TORCH_INTERNAL_ASSERT(
            deletedRRef->rrefId() == rref->rrefId(),
            "Deleting a fork of ",
            rref->rrefId(),
            " triggered deleting the OwnerRRef of ",
            deletedRRef->rrefId());
        // NB: not necessary to reset deletedRRef as rref is another shared_ptr
        // instance pointing to the same OwnerRRef.
      }
    } else {
      // If the parent is the owner, this fork has already been added into the
      // forks_ map when the owner sends the message to the callee user.
      // Hence, it is not necessary to send another RREF_CHILD_ACCEPT or
      // RREF_FORK_REQUEST back to the owner. See Note [Early Fork
      // Registration].
      std::lock_guard<std::mutex> lock(mutex_);
      addConfirmedUser(forkId, rref);
    }
    return;
  }

  // Fork is shared from user.
  if (rref->isOwner()) {
    // See Note [Useful Phantom Fork ID for User to Owner Call]
    // In this case, the owner is the caller, and it does not add the fork id
    // into forks_. Because, there will be no real `UserRRef` associated
    // with this fork ID.
    ++numPendingFutures_;
    auto jitFuture = agent_->sendWithRetries(
        agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());
    jitFuture->addCallback([this](JitFuture& future) {
      handleExceptionSilent(future);
      --numPendingFutures_;
    });
  } else {
    // A user received a fork from another user: ask the owner to register
    // the fork, and track this UserRRef as pending until the owner responds.
    ++numPendingFutures_;
    auto jitFuture = agent_->sendWithRetries(
        agent_->getWorkerInfo(rref->owner()),
        RRefForkRequest(rref->rrefId(), forkId).toMessage());

    addPendingUser(forkId, rref);

    jitFuture->addCallback([this, forkId, parent](JitFuture& future) {
      handleException(future);
      this->finishForkRequest(forkId, parent);
      // Decrease after calling finishForkRequest because, as that creates a new
      // future, it might otherwise cause the count to briefly go to zero.
      --numPendingFutures_;
    });
  }
}
538
539void RRefContext::addPendingChild(
540 const ForkId& forkId,
541 const c10::intrusive_ptr<RRef>& rref) {
542 // see Note [Early Fork Registration]
543 // If the parent is the owner, it should directly add the child UserRRef as a
544 // fork.
545 TORCH_INTERNAL_ASSERT(
546 !rref->isOwner(), "OwnerRRef should not have a pending child.");
547 std::lock_guard<std::mutex> lock(mutex_);
548 TORCH_INTERNAL_ASSERT(
549 pendingChildren_.find(forkId) == pendingChildren_.end(),
550 "Inconsistent states: attempt to add the same child fork twice.");
551 pendingChildren_[forkId] = rref;
552}
553
// Removes a pending child entry once the child worker has ACKed. Idempotent:
// a retried message that targets an already-removed child is logged and
// ignored. The removed UserRRef is destructed only after the lock is
// released (see comments below).
void RRefContext::delPendingChild(const ForkId& forkId) {
  c10::intrusive_ptr<RRef> deletedUser;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = pendingChildren_.find(forkId);
    // We first check whether the child exists in pendingChildren_. It's
    // possible the child may have been removed by a previous send attempt, and
    // this check (as opposed to an assertion here) ensures that messages that
    // trigger this function are idempotent.
    if (iter != pendingChildren_.end()) {
      // Since this UserRRef is removed from the map,
      // the refcount of this UserRRef could reach to 0,
      // so the "destructor", `release_resources()`, might be called,
      // in which the lock is acquired again.
      // So it must be destructed with the lock released.
      // Meet this constraint by creating a temporary pointer to increase the
      // refcount, extending its lifetime until lock released.
      deletedUser = iter->second; // Increase refcount.
      pendingChildren_.erase(iter); // Decrease refcount.
    } else {
      LOG(INFO) << "Ignoring duplicate request to delete child UserRRef with "
                << "ForkId = " << forkId;
    }
  }
  // Wake up delAllUsersAndUnforkedOwners(), which waits on pendingChildren_
  // becoming empty.
  deleteAllUsersCV_.notify_all();
  // The refcount of this UserRRef could reach to 0,
  // so the "destructor", release_resources(), might be called,
  // in which the lock is acquired again,
  // so must destruct it with the lock released.
  deletedUser.reset(); // Decrease refcount.
}
585
// Registers a UserRRef that is awaiting confirmation from its owner. When
// thread-local recording is active, the same PendingUserState is also pushed
// onto userTable_ so waitForThreadLocalPendingRRefs() can wait on it.
void RRefContext::addPendingUser(
    const ForkId& forkId,
    const c10::intrusive_ptr<RRef>& rref) {
  TORCH_INTERNAL_ASSERT(
      !rref->isOwner(), "Attempt to add an OwnerRRef as a pending User.");

  auto state = std::make_shared<PendingUserState>(rref);
  if (recording_) {
    // adding and waiting for pending users are guaranteed to be called from the
    // same thread, but deleting pending users will be called from another
    // thread. As the delPendingUser will not be able to access the same
    // thread_local variable, we cannot address this problem by making
    // pendingUsers_ thread_local. Instead, pendingUsers_ and userTable_ share
    // the same PendingUserState shared_ptr.
    userTable_.push_back(state);
  }

  std::lock_guard<std::mutex> lock(mutex_);
  TORCH_INTERNAL_ASSERT(
      pendingUsers_.find(forkId) == pendingUsers_.end(),
      "Inconsistent states: attempt to add the same UserRRef twice.");

  pendingUsers_.emplace(
      std::piecewise_construct,
      std::forward_as_tuple(forkId),
      std::forward_as_tuple(state));
}
613
// Confirms the pending user identified by forkId: moves it from
// pendingUsers_ to confirmedUsers_, then runs its confirmation callbacks
// outside the lock. The in-body comments explain why the state must outlive
// the critical section.
void RRefContext::delPendingUser(const ForkId& forkId) {
  std::shared_ptr<PendingUserState> deletedState = nullptr;
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto iter = pendingUsers_.find(forkId);
    TORCH_INTERNAL_ASSERT(
        iter != pendingUsers_.end(),
        "Inconsistent states: attempt to delete a non-exist UserRRef.");

    // There are two reasons for keeping the deleted PendingUserState alive
    // until exiting the critical section.
    // (1) Since this UserRRef is removed from the map, the refcount of this
    //     UserRRef could reach to 0. So the resource destructor
    //     (`release_resources()`) might be called, in which the lock is
    //     acquired again. Hence, it must be destructed with the lock released.
    //     To meet this constraint, we intentionally create a temporary pointer
    //     to increase the refcount of the deleted PendingUserState, extending
    //     its lifetime until lock released.
    // (2) Since #34497, a user function only runs after all RRefs in the
    //     arguments are confirmed by their owners, which is done by adding the
    //     RPC processing logic as a callback to the UserRRef ready future. So,
    //     calling `confirm` on the PendingUserState could trigger pending user
    //     functions, which might in turn acquire the lock in RRefContext.
    //     Hence, we must release the lock to prevent deadlock.
    // NB: Another option is to use reentrant lock. However, it is better for
    // the developers to fully understand the locking behavior instead of
    // hiding the subtle logic using a reentrant lock.
    deletedState = iter->second; // Increase refcount

    addConfirmedUser(forkId, iter->second->rref_);
    pendingUsers_.erase(iter); // Decrease refcount.
  }
  // Run confirmation callbacks with the lock released (see (2) above).
  deletedState->confirm();
  // Wake up delAllUsersAndUnforkedOwners(), which waits on pendingUsers_
  // becoming empty.
  deleteAllUsersCV_.notify_all();
  deletedState.reset(); // Decrease refcount.
}
650
651void RRefContext::addConfirmedUser(
652 const ForkId& forkId,
653 const c10::intrusive_ptr<RRef>& rref) {
654 // Notice, caller need to hold the mutex for confirmedUsers_.
655 // std::lock_guard<std::mutex> lock(mutex_);
656 confirmedUsers_.emplace(
657 std::piecewise_construct,
658 std::forward_as_tuple(forkId),
659 std::forward_as_tuple(rref));
660}
661
662c10::intrusive_ptr<RRef> RRefContext::getPendingUser(const ForkId& forkId) {
663 std::lock_guard<std::mutex> lock(mutex_);
664 auto it = pendingUsers_.find(forkId);
665 if (it == pendingUsers_.end()) {
666 TORCH_INTERNAL_ASSERT(
667 false, "Pending user with forkId ", forkId, " not found");
668 }
669 return it->second->rref_;
670}
671
672void RRefContext::recordThreadLocalPendingRRefs() {
673 TORCH_INTERNAL_ASSERT(
674 userTable_.empty(),
675 "User RRef Table should be empty when start recording");
676 recording_ = true;
677}
678
// Ends the thread-local recording session and returns a future that
// completes (with `true`) once every recorded pending UserRRef has been
// confirmed by its owner. Completed immediately when nothing was recorded.
c10::intrusive_ptr<JitFuture> RRefContext::waitForThreadLocalPendingRRefs() {
  // We need to set devices here, even if they won't be used by the value (it's
  // a bool, it doesn't contain tensors!) because we need them to be propagated
  // to child futures. This is silly and we should find a way to avoid this.
  auto jitFuturePtr =
      c10::make_intrusive<JitFuture>(BoolType::get(), agent_->getDevices());
  if (userTable_.empty()) {
    jitFuturePtr->markCompleted(true);
  } else {
    // Count down one per confirmation; the callback that brings the counter
    // from 1 to 0 completes the returned future.
    auto remainingRRefs =
        std::make_shared<std::atomic<uint64_t>>(userTable_.size());
    for (auto& state : userTable_) {
      state->confirmationFuture_->addCallback(
          [jitFuturePtr, remainingRRefs](JitFuture& /* unused */) {
            auto localCount = remainingRRefs->fetch_sub(1);
            if (localCount == 1) {
              jitFuturePtr->markCompleted(true);
            }
          });
    }
    userTable_.clear();
  }
  recording_ = false;
  return jitFuturePtr;
}
704
705void RRefContext::clearRecordedPendingRRefsOnError() {
706 userTable_.clear();
707 recording_ = false;
708}
709
// Runs after the owner ACKs a fork request: confirms the pending user
// locally, then sends RREF_CHILD_ACCEPT to the parent worker (presumably
// triggering delPendingChild there — confirm against the request handler).
void RRefContext::finishForkRequest(const ForkId& forkId, worker_id_t parent) {
  delPendingUser(forkId);
  ++numPendingFutures_;
  auto jitFuture = agent_->sendWithRetries(
      agent_->getWorkerInfo(parent), RRefChildAccept(forkId).toMessage());

  jitFuture->addCallback([this](JitFuture& future) {
    handleExceptionSilent(future);
    --numPendingFutures_;
  });
}
721
722void RRefContext::addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref) {
723 std::lock_guard<std::mutex> lock(mutex_);
724 const auto& rrefId = rref->rrefId();
725 owners_[rrefId] = rref;
726 auto& rrefForks = forks_[rrefId];
727 TORCH_INTERNAL_ASSERT(
728 rrefForks.find(rrefId) == rrefForks.end(),
729 "Attempt to add self as fork twice ",
730 rrefId);
731 rrefForks.insert(rrefId);
732}
733
734void RRefContext::addForkOfOwner(const RRefId& rrefId, const ForkId& forkId) {
735 std::lock_guard<std::mutex> lock(mutex_);
736 auto& rrefForks = forks_[rrefId];
737 TORCH_INTERNAL_ASSERT(
738 rrefForks.find(forkId) == rrefForks.end(),
739 "Got fork notification twice on the same RRef ",
740 forkId);
741 rrefForks.insert(forkId);
742}
743
744void RRefContext::addForkOfOwnerIfNotPresent(
745 const RRefId& rrefId,
746 const ForkId& forkId) {
747 std::lock_guard<std::mutex> lock(mutex_);
748 auto& rrefForks = forks_[rrefId];
749 // We first check whether the child exists in rrefForks. It's possible
750 // the child may have been added by a previous send attempt, and this check
751 // (as opposed to an assertion here) ensures that messages that trigger this
752 // function are idempotent.
753 if (rrefForks.find(forkId) == rrefForks.end()) {
754 rrefForks.insert(forkId);
755 } else {
756 LOG(INFO) << "Ignoring duplicate request to add Fork of OwnerRRef with "
757 << "RRefId = " << rrefId << ", ForkId = " << forkId;
758 }
759}
760
// Removes forkId from the fork set of rrefId. If that was the last fork, the
// OwnerRRef is also removed from owners_ and returned to the caller so it is
// destructed outside the lock; otherwise returns nullptr. Idempotent by
// design so retried RRefUserDelete messages are safe (see comment below).
c10::intrusive_ptr<RRef> RRefContext::delForkOfOwner(
    const RRefId& rrefId,
    const ForkId& forkId) {
  c10::intrusive_ptr<RRef> deletedRRef;
  bool ownerReduced = false;
  // There were previously multiple TORCH_CHECKs in this function that checked
  // whether the passed in fork was known by the user and whether the fork had
  // already been deleted. These assertions are now replaced with nested if
  // statements to ensure this function is idempotent. This makes it safe to
  // retry RRefUserDelete messages.
  {
    std::lock_guard<std::mutex> lock(mutex_);
    auto rrefIter = forks_.find(rrefId);
    if (rrefIter != forks_.end()) {
      auto& rrefForks = rrefIter->second;
      auto forkIter = rrefForks.find(forkId);
      if (forkIter != rrefForks.end()) {
        rrefForks.erase(forkId);
      } else {
        LOG(INFO)
            << "Could not find UserRRef instance, "
            << "RRefId = " << rrefId << ", ForkId = " << forkId
            << ", likely because it was deleted by a previously retried message";
      }
      if (rrefForks.empty()) {
        auto ownerIter = owners_.find(rrefId);
        if (ownerIter != owners_.end()) {
          deletedRRef = ownerIter->second;
          owners_.erase(ownerIter);
          ownerReduced = true;
        }
        forks_.erase(rrefIter);
      }
    } else {
      LOG(INFO)
          << "Could not find OwnerRRef with RRefId = " << rrefId
          << ", likely because it was deleted by a previously retried message";
    }
  }
  if (ownerReduced) {
    // Wake up delAllUsersAndUnforkedOwners(), which waits on owners_
    // becoming empty.
    deleteAllUsersCV_.notify_all();
  }
  return deletedRRef;
}
805
806} // namespace rpc
807} // namespace distributed
808} // namespace torch
809