#include <torch/csrc/distributed/rpc/rref_impl.h>

#include <ATen/record_function.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <fmt/format.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace {
// If the type is a subtype of a named type, return its qualified name;
// otherwise return its annotation string.
std::string getTypeStr(const c10::TypePtr& type) {
  switch (type->kind()) {
    case c10::TypeKind::FunctionType:
      return type->castRaw<c10::FunctionType>()->name()->qualifiedName();
    case c10::TypeKind::TupleType:
      return type->castRaw<c10::TupleType>()->name()->qualifiedName();
    case c10::TypeKind::ClassType:
      return type->castRaw<c10::ClassType>()->name()->qualifiedName();
    case c10::TypeKind::InterfaceType:
      return type->castRaw<c10::InterfaceType>()->name()->qualifiedName();
    default:
      return type->annotation_str();
  }
}
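
// Illustrative sketch (not from the original file): named types resolve to
// their qualified names, while unnamed types fall back to annotation_str().
// Assuming a TorchScript class registered as __torch__.MyClass:
//
//   getTypeStr(myClassType);                                 // "__torch__.MyClass"
//   getTypeStr(c10::ListType::create(c10::IntType::get()));  // "List[int]"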

} // namespace

namespace torch {
namespace distributed {
namespace rpc {

std::atomic<local_id_t> RRefContext::nextLocalId_{0};

////////////////////////// RRefForkData /////////////////////////////////

RRefForkData::RRefForkData(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    worker_id_t parent,
    std::string typeStr)
    : ownerId_(ownerId),
      rrefId_(rrefId),
      forkId_(forkId),
      parent_(parent),
      typeStr_(std::move(typeStr)) {}

////////////////////////////// RRef /////////////////////////////////////

RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type)
    : RRefInterface(),
      ownerId_(ownerId),
      rrefId_(rrefId),
      type_(std::move(type)) {}

RRefForkData RRef::fork() const {
  auto& ctx = RRefContext::getInstance();
  return RRefForkData(
      ownerId_,
      rrefId_,
      ctx.genGloballyUniqueId(),
      ctx.getWorkerId(),
      getTypeStr(type_));
}
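
// Illustrative sketch (an assumption, not from the original file): forking
// keeps the owner id and RRefId but mints a fresh, globally unique ForkId, so
// the owner can track each user-side reference independently:
//
//   RRefForkData rfd = rref->fork();
//   // rfd.ownerId_ == rref->owner();  rfd.rrefId_ == rref->rrefId();
//   // rfd.forkId_ is newly generated; rfd.parent_ is this worker's id.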

void RRef::handleError(RPCErrorType errorType, const JitFuture& jitFuture) {
  // Note: the handler map is built per call rather than as a function-local
  // static, because the TIMEOUT and INTENTIONAL_FAILURE handlers capture
  // `this`; a static map would keep the `this` of the first caller forever.
  std::unordered_map<
      RPCErrorType,
      std::function<void(const JitFuture& jitFuture)>,
      std::hash<int>>
      errorHandlers = {
          {RPCErrorType::TIMEOUT,
           [this](const JitFuture& /* unused */) { setTimedOut(); }},
          {RPCErrorType::INTENTIONAL_FAILURE,
           [this](const JitFuture& /* unused */) { setTimedOut(); }},
          {RPCErrorType::UNKNOWN_ERROR, [](const JitFuture& jitFuture) {
             // Default error handler
             RRefContext::handleException(jitFuture);
           }}};
  auto errorHandlerIt = errorHandlers.find(errorType);
  TORCH_INTERNAL_ASSERT(
      errorHandlerIt != errorHandlers.end(),
      "Unknown RPCErrorType: ",
      static_cast<int>(errorType));
  errorHandlerIt->second(jitFuture);
}
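
// Illustrative sketch (an assumption): a send-side completion callback might
// classify a failed future and route it through handleError, e.g.:
//
//   jitFuture->addCallback([rref](JitFuture& future) {
//     if (future.hasError()) {
//       rref->handleError(getRPCErrorType(future), future);
//     }
//   });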

////////////////////////// UserRRef /////////////////////////////////////

UserRRef::UserRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    const ForkId& forkId,
    TypePtr type)
    : RRef(ownerId, rrefId, std::move(type)),
      forkId_(forkId),
      confirmedByOwner_(false) {
  // Do nothing here; owner-side bookkeeping happens elsewhere:
  // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send
  //     a RREF_FORK_REQUEST message to the owner.
  // (2) If this is the creator UserRRef, ScriptRemoteCall or PythonRemoteCall
  //     will properly notify the owner.
}

void UserRRef::tryDel() {
  std::lock_guard<std::mutex> lockGuard(deletedOnOwnerMutex_);
  if (!deletedOnOwner_) {
    try {
      RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_);
      deletedOnOwner_ = true;
    } catch (const std::exception& ex) {
      LOG(ERROR) << "Error occurred when deleting " << *this << " : "
                 << ex.what();
    } catch (...) {
      LOG(ERROR) << "Error occurred when deleting " << *this << " : "
                 << "unknown error";
    }
  }
}

UserRRef::~UserRRef() {
  tryDel();
}

void UserRRef::release_resources() {
  tryDel();
}
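
// Illustrative usage (an assumption): deletion is reference-driven. When the
// last local reference to a UserRRef is dropped, tryDel() runs and notifies
// the owner via RRefContext::delUser:
//
//   {
//     c10::intrusive_ptr<UserRRef> user = /* obtained from RRefContext */;
//     // ... use the RRef ...
//   } // ~UserRRef() runs here and sends the user-delete message to the owner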

const ForkId& UserRRef::forkId() const {
  return forkId_;
}

IValue UserRRef::toHere(const float timeoutSeconds) const {
  TORCH_CHECK(
      !getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  // see Note [Best-Effort Check on Deleted UserRRefs]
  TORCH_CHECK(
      !deletedOnOwner_,
      *this,
      " has been deleted. Cannot call to_here() on it after deletion.");
  auto toHereKey = std::string("");
  if (torch::autograd::profiler::profilerEnabled()) {
    toHereKey = fmt::format(
        "to_here#({})->({})",
        RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
        RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_).name_);
  }
  RECORD_USER_SCOPE(toHereKey);
  TORCH_CHECK(
      !type_->is_module(),
      *this,
      " is an RRef to a ScriptModule. "
      "It can't be sent through RPC "
      "from owner, ",
      ownerWorkerInfo(),
      ", to user, ",
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo(),
      ".");

  auto agent = RpcAgent::getCurrentRpcAgent();
  // A ScriptRRefFetchCall message always carries the autograd context id, even
  // if the message itself does not contain any tensors, because the response
  // could potentially contain tensors.
  c10::intrusive_ptr<Message> msgToSend;

  if (isPyObj()) {
    msgToSend = PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
  } else {
    msgToSend = ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
  }

  // toHere is profiled as a blocking call, and does not execute operations on
  // the remote node. Hence, don't wrap it with a profiling message since we
  // don't need the profiler to be enabled remotely.
  auto jitFuture = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      std::move(msgToSend),
      true /* forceGradRecording */,
      timeoutSeconds,
      true /* forceDisableProfiling */);

  // TODO: we should ideally be able to interrupt this blocking wait if we
  // check getTimedOut() and it is true
  // (https://github.com/pytorch/pytorch/issues/39411).
  jitFuture->waitAndThrow();
  auto messagePtr = jitFuture->constValue().toCustomClass<Message>();
  MessageType msgType = messagePtr->type();
  auto response = deserializeResponse(*messagePtr, msgType);
  TORCH_INTERNAL_ASSERT(
      msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
          msgType == MessageType::PYTHON_RREF_FETCH_RET,
      "Message type should either be SCRIPT_RREF_FETCH_RET "
      "or PYTHON_RREF_FETCH_RET");
  RpcCommandBase& rpc = *response;
  auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
  if (isPyObj()) {
    // Wrap the Python-serialized vector of IValues into a tuple, so that the
    // C++ toHere interface returns a single IValue.
    return ivalue::Tuple::create(rrefFetchRet.values());
  } else {
    return rrefFetchRet.values().front();
  }
}
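
// Illustrative usage (an assumption): toHere blocks the caller until the owner
// replies with the value, subject to the given timeout in seconds:
//
//   IValue value = userRRef->toHere(/*timeoutSeconds=*/60.0);
//   // For Python objects, value is a tuple wrapping the serialized payload;
//   // for script values, it is the value itself.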

RRefForkData UserRRef::fork() const {
  // Note [Best-Effort Check on Deleted UserRRefs]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // This check does not guarantee correctness, as there could be another
  // thread trying to delete this UserRRef concurrently. Passing this check
  // does not mean this RRef will be alive throughout this function. This is
  // just our best-effort attempt to raise proper error messages. The behavior
  // of using deleted UserRRefs is undefined.
  //
  // The reasons for not implementing strict checks are:
  // 1. This would need to acquire the lock on deletedOnOwnerMutex_, which
  //    would introduce unnecessary overhead for most normal use cases.
  // 2. This would introduce a lot of complexity to get the behavior correct.
  //    Assume we acquired the lock here, and there is another thread X blocked
  //    waiting in tryDel() on the lock. Exiting this fork function would
  //    unblock thread X. However, while X proceeds with deleting this
  //    UserRRef, the call site of fork() might have added the UserRRef to the
  //    pendingChildren_ map, but up to this point, nothing prevents X from
  //    deleting this RRef even if it shouldn't do so due to the state change
  //    in pendingChildren_. We might be able to get it right for now by
  //    locking and checking pendingChildren_ in X, but the gain does not seem
  //    worth the complexity.
  TORCH_CHECK(
      !deletedOnOwner_,
      *this,
      " has been deleted. Cannot fork a UserRRef after deletion.");
  return RRef::fork();
}

////////////////////////// OwnerRRef /////////////////////////////////////

OwnerRRef::OwnerRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    TypePtr type,
    std::vector<c10::Device> devices)
    : OwnerRRef(ownerId, rrefId, type, /* value */ {}, std::move(devices)) {}

OwnerRRef::OwnerRRef(
    worker_id_t ownerId,
    const RRefId& rrefId,
    TypePtr type,
    c10::optional<IValue> value,
    std::vector<c10::Device> devices)
    : RRef(ownerId, rrefId, type) {
  future_ = c10::make_intrusive<JitFuture>(type_, std::move(devices));

  if (value.has_value()) {
    future_->markCompleted(value.value());
  }
}

const IValue& OwnerRRef::getValue() const {
  TORCH_CHECK(
      !getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  future_->waitAndThrow();
  return future_->constValue();
}

bool OwnerRRef::hasValue() const {
  return future_->completed();
}

c10::intrusive_ptr<JitFuture> OwnerRRef::getFuture() {
  return future_;
}
void OwnerRRef::setValue(IValue&& value) {
  future_->markCompleted(std::move(value));
}

void OwnerRRef::setError(std::exception_ptr eptr) {
  future_->setErrorIfNeeded(std::move(eptr));
}
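
// Illustrative owner-side flow (an assumption): an OwnerRRef is backed by a
// JitFuture, so readers block in getValue() until a producer completes it:
//
//   // owner is a c10::intrusive_ptr<OwnerRRef>
//   owner->hasValue();                    // false until setValue()/setError()
//   owner->setValue(IValue(42));          // completes the underlying future
//   const IValue& v = owner->getValue();  // now returns immediately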

std::ostream& operator<<(std::ostream& os, const RRef& rref) {
  if (rref.isOwner()) {
    return os << "OwnerRRef("
              << "rref_id=" << rref.rrefId() << ")";
  } else {
    return os << "UserRRef("
              << "rref_id=" << rref.rrefId()
              << ", fork_id=" << static_cast<const UserRRef*>(&rref)->forkId()
              << ")";
  }
}
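
// Illustrative output (an assumption about how RRefId/ForkId print):
//
//   LOG(INFO) << *rref;
//   // => "OwnerRRef(rref_id=<RRefId>)" on the owner, or
//   // => "UserRRef(rref_id=<RRefId>, fork_id=<ForkId>)" on a user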

} // namespace rpc
} // namespace distributed
} // namespace torch