1 | #include <torch/csrc/distributed/rpc/rref_impl.h> |
2 | |
3 | #include <ATen/record_function.h> |
4 | #include <c10/core/impl/DeviceGuardImplInterface.h> |
5 | #include <fmt/format.h> |
6 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h> |
7 | #include <torch/csrc/distributed/autograd/utils.h> |
8 | #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h> |
9 | #include <torch/csrc/distributed/rpc/rref_context.h> |
10 | #include <torch/csrc/distributed/rpc/rref_proto.h> |
11 | #include <torch/csrc/distributed/rpc/utils.h> |
12 | |
13 | namespace { |
14 | // If the type is subtype of named type, return its qualifiedname, otherwise |
15 | // return its type str. |
16 | std::string getTypeStr(const c10::TypePtr& type) { |
17 | switch (type->kind()) { |
18 | case c10::TypeKind::FunctionType: |
19 | return type->castRaw<c10::FunctionType>()->name()->qualifiedName(); |
20 | case c10::TypeKind::TupleType: |
21 | return type->castRaw<c10::TupleType>()->name()->qualifiedName(); |
22 | case c10::TypeKind::ClassType: |
23 | return type->castRaw<c10::ClassType>()->name()->qualifiedName(); |
24 | case c10::TypeKind::InterfaceType: |
25 | return type->castRaw<c10::InterfaceType>()->name()->qualifiedName(); |
26 | default: |
27 | return type->annotation_str(); |
28 | } |
29 | } |
30 | |
31 | } // namespace |
32 | |
33 | namespace torch { |
34 | namespace distributed { |
35 | namespace rpc { |
36 | |
37 | std::atomic<local_id_t> RRefContext::nextLocalId_{0}; |
38 | |
39 | ////////////////////////// RRefForkData ///////////////////////////////// |
40 | |
41 | RRefForkData::RRefForkData( |
42 | worker_id_t ownerId, |
43 | const RRefId& rrefId, |
44 | const ForkId& forkId, |
45 | worker_id_t parent, |
46 | std::string typeStr) |
47 | : ownerId_(ownerId), |
48 | rrefId_(rrefId), |
49 | forkId_(forkId), |
50 | parent_(parent), |
51 | typeStr_(std::move(typeStr)) {} |
52 | |
53 | ////////////////////////////// RRef ///////////////////////////////////// |
54 | |
55 | RRef::RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type) |
56 | : RRefInterface(), |
57 | ownerId_(ownerId), |
58 | rrefId_(rrefId), |
59 | type_(std::move(type)) {} |
60 | |
61 | RRefForkData RRef::fork() const { |
62 | auto& ctx = RRefContext::getInstance(); |
63 | return RRefForkData( |
64 | ownerId_, |
65 | rrefId_, |
66 | ctx.genGloballyUniqueId(), |
67 | ctx.getWorkerId(), |
68 | getTypeStr(type_)); |
69 | } |
70 | |
71 | void RRef::handleError(RPCErrorType errorType, const JitFuture& jitFuture) { |
72 | static std::unordered_map< |
73 | RPCErrorType, |
74 | std::function<void(const JitFuture& jitFuture)>, |
75 | std::hash<int>> |
76 | errorHandlers = { |
77 | {RPCErrorType::TIMEOUT, |
78 | [this](const JitFuture& /* unused */) { setTimedOut(); }}, |
79 | {RPCErrorType::INTENTIONAL_FAILURE, |
80 | [this](const JitFuture& /* unused */) { setTimedOut(); }}, |
81 | {RPCErrorType::UNKNOWN_ERROR, [](const JitFuture& jitFuture) { |
82 | // Default error handler |
83 | RRefContext::handleException(jitFuture); |
84 | }}}; |
85 | errorHandlers.find(errorType)->second(jitFuture); |
86 | } |
87 | |
88 | ////////////////////////// UserRRef ///////////////////////////////////// |
89 | |
90 | UserRRef::UserRRef( |
91 | worker_id_t ownerId, |
92 | const RRefId& rrefId, |
93 | const ForkId& forkId, |
94 | TypePtr type) |
95 | : RRef(ownerId, rrefId, std::move(type)), |
96 | forkId_(forkId), |
97 | confirmedByOwner_(false) { |
98 | // Do nothing, |
99 | // (1) If this UserRRef is a fork of an existing RRef, RRefContext will send |
100 | // a RREF_FORK_REQUEST message to the owner. |
101 | // (2) If this the creator UserRRef, ScriptRemoteCall or PythonRemoteCall will |
102 | // properly notify the owner. |
103 | } |
104 | |
105 | void UserRRef::tryDel() { |
106 | std::lock_guard<std::mutex> lockGuard(deletedOnOwnerMutex_); |
107 | if (!deletedOnOwner_) { |
108 | try { |
109 | RRefContext::getInstance().delUser(ownerId_, rrefId_, forkId_); |
110 | deletedOnOwner_ = true; |
111 | } catch (const std::exception& ex) { |
112 | LOG(ERROR) << "Error occurred when deleting" << *this << " : " |
113 | << ex.what(); |
114 | } catch (...) { |
115 | LOG(ERROR) << "Error occurred when deleting" << *this << " : " |
116 | << "unknown error" ; |
117 | } |
118 | } |
119 | } |
120 | |
121 | UserRRef::~UserRRef() { |
122 | tryDel(); |
123 | } |
124 | |
125 | void UserRRef::release_resources() { |
126 | tryDel(); |
127 | } |
128 | |
129 | const ForkId& UserRRef::forkId() const { |
130 | return forkId_; |
131 | } |
132 | |
// Fetches a copy of the value held by this RRef's owner, blocking until the
// owner replies.
//
// @param timeoutSeconds  forwarded unchanged to sendMessageWithAutograd() as
//                        the RPC timeout.
// @return for a Python-object RRef (isPyObj()), all fetched values wrapped in
//         a single ivalue::Tuple; otherwise the first fetched value.
// @throws c10::Error (TORCH_CHECK) if the rpc.remote() call that created this
//         RRef timed out, if this UserRRef was already deleted on the owner
//         (best-effort check), or if the RRef refers to a ScriptModule.
//         Errors stored in the RPC future are rethrown by waitAndThrow().
IValue UserRRef::toHere(const float timeoutSeconds) const {
  TORCH_CHECK(
      !getTimedOut(),
      "RRef creation via rpc.remote() timed out, and it "
      "is possible that the RRef on the owner node does not exist.");
  // see Note [Best-Effort Check on Deleted UserRRefs]
  TORCH_CHECK(
      !deletedOnOwner_,
      *this,
      " has been deleted. Cannot call to_here() on it after deletion.");
  // The "to_here#(caller)->(owner)" label is only formatted when the profiler
  // is enabled; otherwise toHereKey stays empty.
  auto toHereKey = std::string("");
  if (torch::autograd::profiler::profilerEnabled()) {
    toHereKey = fmt::format(
        "to_here#({})->({})",
        RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
        RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_).name_);
  }
  // NOTE(review): RECORD_USER_SCOPE presumably opens a RecordFunction scope
  // that lasts until this function returns; confirm in record_function.h.
  RECORD_USER_SCOPE(toHereKey);
  // A ScriptModule value cannot be shipped from owner to user over RPC, so
  // to_here() is rejected for RRefs to modules.
  TORCH_CHECK(
      !type_->is_module(),
      *this,
      " is an RRef to a ScriptModule. "
      "It can't be sent through RPC "
      "from owner, ",
      ownerWorkerInfo(),
      ", to user, ",
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo(),
      ".");

  auto agent = RpcAgent::getCurrentRpcAgent();

  // The fetch request always carries the autograd context id (it is sent with
  // forceGradRecording below) even though the request itself contains no
  // tensors, because the response may contain tensors.
  c10::intrusive_ptr<Message> msgToSend;

  if (isPyObj()) {
    msgToSend = PythonRRefFetchCall(ownerId_, rrefId()).toMessage();
  } else {
    msgToSend = ScriptRRefFetchCall(ownerId_, rrefId()).toMessage();
  }

  // toHere is profiled as a blocking call, and does not execute operations on
  // the remote node. Hence, don't wrap it with a profiling message since we
  // don't need the profiler to be enabled remotely.
  auto jitFuture = autograd::sendMessageWithAutograd(
      *agent,
      agent->getWorkerInfo(ownerId_),
      std::move(msgToSend),
      true /* forceGradRecording */,
      timeoutSeconds,
      true /* forceDisableProfiling */);

  // TODO: we should ideally be able to interrupt this blocking wait if we check
  // getTimedOut() and it is true
  // (https://github.com/pytorch/pytorch/issues/39411).
  jitFuture->waitAndThrow();
  auto messagePtr = jitFuture->constValue().toCustomClass<Message>();
  MessageType msgType = messagePtr->type();
  // NOTE(review): msgType is passed to deserializeResponse(), which appears to
  // take it by reference and replace it with the wrapped message type (the
  // raw reply is likely an autograd wrapper, since the request was sent with
  // forceGradRecording). That is why the type assert below must stay AFTER
  // this call; confirm against rpc/utils.h before reordering.
  auto response = deserializeResponse(*messagePtr, msgType);
  TORCH_INTERNAL_ASSERT(
      msgType == MessageType::SCRIPT_RREF_FETCH_RET ||
          msgType == MessageType::PYTHON_RREF_FETCH_RET,
      "Message type should either be SCRIPT_RREF_FETCH_RET "
      "or PYTHON_RREF_FETCH_RET");
  RpcCommandBase& rpc = *response;
  auto& rrefFetchRet = static_cast<RRefFetchRet&>(rpc);
  if (isPyObj()) {
    // Wrap the Python-serialized vector of IValues into a tuple so that the
    // C++ toHere interface returns a single IValue.
    return ivalue::Tuple::create(rrefFetchRet.values());
  } else {
    return rrefFetchRet.values().front();
  }
}
208 | |
209 | RRefForkData UserRRef::fork() const { |
210 | // Note [Best-Effort Check on Deleted UserRRefs] |
211 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
212 | // This check does not guarantee correctness, as there could be another thread |
213 | // trying to delete this UserRRef concurrently. Passing this check does not |
214 | // mean this RRef will be alive throughout this function. This is just our |
215 | // best-effort attempt to raise proper error messages. The behavior of using |
216 | // deleted UserRRefs is undefined. |
217 | // |
218 | // The reason for not implementing strict checks are: |
219 | // 1. This would need to acquire lock on deletedOnOwnerMutex_, which would |
220 | // introduce unnecessary overhead for most normal use cases. |
221 | // 2. This would introduce a lot of complexities to get the behavior correct. |
222 | // Assume we acquired the lock here, and there is another thread X block |
223 | // waiting in tryDel() on the lock. Exiting this fork function would |
224 | // unblock thread X. However, while X proceeds with deleting this UserRRef, |
225 | // the call site of fork() might have added the UserRRef to |
226 | // pendingChildren_ map, but up to this point, nothing prevents X from |
227 | // deleting this RRef even if it shouldn't do so due to the state change |
228 | // in pendingChildren_. We might be able to get it right for now by locking |
229 | // and checking pendingChildren_ in X, but the gain does not seem to |
230 | // worth the complexity. |
231 | TORCH_CHECK( |
232 | !deletedOnOwner_, |
233 | *this, |
234 | " has been deleted. Cannot call fork an UserRRef after deletion." ); |
235 | return RRef::fork(); |
236 | } |
237 | |
238 | ////////////////////////// OwnerRRef ///////////////////////////////////// |
239 | |
240 | OwnerRRef::OwnerRRef( |
241 | worker_id_t ownerId, |
242 | const RRefId& rrefId, |
243 | TypePtr type, |
244 | std::vector<c10::Device> devices) |
245 | : OwnerRRef(ownerId, rrefId, type, /* value */ {}, std::move(devices)) {} |
246 | |
247 | OwnerRRef::OwnerRRef( |
248 | worker_id_t ownerId, |
249 | const RRefId& rrefId, |
250 | TypePtr type, |
251 | c10::optional<IValue> value, |
252 | std::vector<c10::Device> devices) |
253 | : RRef(ownerId, rrefId, type) { |
254 | future_ = c10::make_intrusive<JitFuture>(type_, std::move(devices)); |
255 | |
256 | if (value.has_value()) { |
257 | future_->markCompleted(value.value()); |
258 | } |
259 | } |
260 | |
261 | const IValue& OwnerRRef::getValue() const { |
262 | TORCH_CHECK( |
263 | !getTimedOut(), |
264 | "RRef creation via rpc.remote() timed out, and it " |
265 | "is possible that the RRef on the owner node does not exist." ); |
266 | future_->waitAndThrow(); |
267 | return future_->constValue(); |
268 | } |
269 | |
270 | bool OwnerRRef::hasValue() const { |
271 | return future_->completed(); |
272 | } |
273 | |
274 | c10::intrusive_ptr<JitFuture> OwnerRRef::getFuture() { |
275 | return future_; |
276 | } |
277 | |
278 | void OwnerRRef::setValue(IValue&& value) { |
279 | future_->markCompleted(value); |
280 | } |
281 | |
282 | void OwnerRRef::setError(std::exception_ptr eptr) { |
283 | future_->setErrorIfNeeded(std::move(eptr)); |
284 | } |
285 | |
286 | std::ostream& operator<<(std::ostream& os, const RRef& rref) { |
287 | if (rref.isOwner()) { |
288 | return os << "OwnerRRef(" |
289 | << "rref_id=" << rref.rrefId() << ")" ; |
290 | } else { |
291 | return os << "UserRRef(" |
292 | << "rref_id=" << rref.rrefId() |
293 | << ", fork_id=" << static_cast<const UserRRef*>(&rref)->forkId() |
294 | << ")" ; |
295 | } |
296 | } |
297 | |
298 | } // namespace rpc |
299 | } // namespace distributed |
300 | } // namespace torch |
301 | |