#pragma once

#include <ATen/core/jit_type.h>
#include <ATen/core/rref_interface.h>
#include <c10/core/Event.h>
#include <c10/util/Optional.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/types.h>

#include <atomic>
#include <exception>
#include <mutex>
#include <ostream>
#include <string>
#include <vector>
12 | |
13 | namespace torch { |
14 | namespace distributed { |
15 | namespace rpc { |
16 | |
class RRef;
class RRefContext;
class UserRRef;

// Positions of the RRefForkData fields inside the py::tuple form of the fork
// data. The order follows the field order of RRefForkData below.
constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of the type string (typeStr) in the tuple

// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple

// Enforce the NB above at compile time: the last field index must occupy the
// final slot of the tuple, so adding an index without bumping the size fails
// to build instead of silently producing a short tuple.
static_assert(
    TYPE_IDX + 1 == RFD_TUPLE_SIZE,
    "RFD_TUPLE_SIZE must be bumped when RRefForkData tuple fields are added");
31 | |
32 | // Represents fork of an RRef to be sent over the wire. |
33 | struct TORCH_API RRefForkData { |
34 | const worker_id_t ownerId_; |
35 | const RRefId rrefId_; |
36 | const ForkId forkId_; |
37 | const worker_id_t parent_; |
38 | const std::string typeStr_; |
39 | |
40 | RRefForkData( |
41 | worker_id_t ownerId, |
42 | const RRefId& rrefId, |
43 | const ForkId& forkId, |
44 | worker_id_t parent, |
45 | std::string typeStr); |
46 | }; |
47 | |
48 | // Note [RRef Protocol] |
49 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~ |
50 | // |
51 | // [Background] |
52 | // |
53 | // RRef stands for Remote REFerence. Each RRef is owned by a single worker |
54 | // (i.e., owner) and can be used by multiple users. The owner stores the real |
55 | // data referenced by its RRefs. RRef needs to support fast and scalable RPC. |
56 | // Hence, in the design, we avoid using a single global master to keep RRef |
57 | // states, instead owners will keep track of the global reference counts |
58 | // for its RRefs. Every RRef can be uniquely identified by a global RRefId, |
59 | // which is assigned at the time it is first created either on a user or on the |
60 | // owner. |
61 | // |
62 | // On the owner worker, there is only one OwnerRRef instance, which contains the |
63 | // real data, while on user workers, there can be as many UserRRefs as |
64 | // necessary, and UserRRef does not hold the data. All usage on the OwnerRRef |
65 | // should retrieve the unique OwnerRRef instance using the globally unique |
// RRefId. A UserRRef will be created when it is used as an argument or return
67 | // value in dist.rpc or dist.remote call, but RRef forking and reference |
68 | // counting (RC) are completely transparent to applications. Every UserRRef will |
69 | // also have its globally unique ForkId. |
70 | // |
71 | // [Assumptions] |
72 | // |
73 | // 1. Transient Network Failures |
74 | // |
75 | // TODO: current RRef implementation does not tolerate failures |
76 | // |
77 | // The RRef design handles transient network failures by retrying |
// messages. Node crashes and permanent network partitions are out of scope.
79 | // When those incidents occur, the application may take down all workers, revert |
80 | // to the previous checkpoint, and resume training. |
81 | // |
82 | // 2. Non-idempotent UDFs |
83 | // |
84 | // We assume UDFs are not idempotent and therefore cannot be retried. However, |
85 | // internal RRef control messages are idempotent and retried upon message |
86 | // failure. |
87 | // |
88 | // TODO: RRef internal messages are not yet idempotent |
89 | // |
90 | // 3. Out of Order Message Delivery |
91 | // |
92 | // We do not assume message delivery order between any pair of nodes, because |
93 | // both sender and receiver are using multiple threads. There is no guarantee on |
94 | // which message will be processed first. |
95 | // |
96 | // [RRef Lifetime] |
97 | // |
98 | // The goal of the protocol is to delete an OwnerRRef at an appropriate time. |
99 | // The right time to delete an OwnerRRef is when there are no living UserRRefs |
100 | // and Python GC also agrees to delete the OwnerRRef instance on the owner. The |
101 | // tricky part is to determine if there are any living UserRRefs. |
102 | // |
103 | // A user can get a UserRRef in three situations: |
104 | // |
105 | // (1). Receiving a UserRRef from the owner. |
106 | // (2). Receiving a UserRRef from another user. |
107 | // (3). Creating a new UserRRef owned by another worker. |
108 | // |
109 | // (1) is the simplest case where the owner initiates the fork, and hence it can |
110 | // easily increment local RC. The only requirement is that any UserRRef must |
111 | // notify the owner before destruction. Hence, we need the first guarantee: |
112 | // |
113 | // G1. The owner will be notified when any UserRRef is deleted. |
114 | // |
// As messages might come delayed or out-of-order, we need one more guarantee to
116 | // make sure the delete message is not sent out too soon. Let us first introduce |
117 | // a new concept. If A sends an RPC to B that involves an RRef, we call the RRef |
118 | // on A the parent RRef and the RRef on B the child RRef. |
119 | // |
120 | // G2. Parent RRef cannot be deleted until the child RRef is confirmed by the |
121 | // owner. |
122 | // |
123 | // Under (1), where the caller is UserRRef and callee is OwnerRRef, it simply |
124 | // means that the user will not send out the delete message until all previous |
// messages are ACKed. Note that ACKed does not mean the owner has finished
// executing the function; instead, it only means the owner has retrieved its
// local OwnerRRef and is about to pass it to the function, which is
// sufficient to keep the OwnerRRef alive even if the delete message from the
// user arrives at the owner before the function finishes execution.
130 | // |
131 | // With (2) and (3), it is possible that the owner only partially knows the RRef |
// fork graph or does not know it at all. For example, the RRef could be
133 | // constructed on a user, and before the owner receives the RPC call, the |
134 | // creator user might have already shared the RRef with other users, and those |
135 | // users could further share the RRef. One invariant is that the fork graph of |
136 | // any RRef is always a tree rooted at the owner, because forking an RRef always |
137 | // creates a new RRef instance, and hence every RRef has a single parent. One |
138 | // nasty detail is that when an RRef is created on a user, technically the owner |
139 | // is not its parent but we still consider it that way and it does not break the |
140 | // argument below. |
141 | // |
142 | // The owner's view on any node (fork) in the tree has three stages: |
143 | // |
144 | // 1) unknown -> 2) known -> 3) deleted. |
145 | // |
146 | // The owner's view on the entire tree keeps changing. The owner deletes its |
147 | // OwnerRRef instance when it thinks there are no living UserRRefs, i.e., when |
148 | // OwnerRRef is deleted, all UserRRefs could be either indeed deleted or |
149 | // unknown. The dangerous case is when some forks are unknown and others are |
150 | // deleted. |
151 | // |
152 | // G2 trivially guarantees that no parent UserRRef Y can be deleted before the |
153 | // owner knows all of Y's children UserRRefs. |
154 | // |
155 | // However, it is possible that the child UserRRef Z may be deleted before the |
156 | // owner knows its parent Y. More specifically, this can happen when all of Z's |
157 | // messages are processed by the owner before all messages from Y, including the |
// delete message. Nevertheless, this does not cause any problem, because at
// least one of Y's ancestors will be alive, and it will prevent the owner from
// deleting the OwnerRRef. Consider the following example: (NB: this scenario
// will no longer be relevant when we block UDFs until all RRefs are confirmed
// by the owner.)
163 | // |
164 | // OwnerRRef -> A -> Y -> Z |
165 | // |
166 | // OwnerRRef forks to A, then A forks to Y, and Y forks to Z. Z can be deleted |
167 | // without OwnerRRef knowing Y. However, the OwnerRRef will at least know A, as |
168 | // the owner directly forks the RRef to A. A won't die before the owner knows Y. |
169 | // |
170 | // Things get a little trickier if the RRef is created on a user: |
171 | // |
172 | // OwnerRRef |
173 | // ^ |
174 | // | |
175 | // A -> Y -> Z |
176 | // |
177 | // If Z calls to_here on the UserRRef, the owner at least knows A when Z is |
178 | // deleted, because otherwise to_here wouldn't finish. If Z does not call |
179 | // to_here, it is possible that the owner receives all messages from Z before |
180 | // any message from A and Y. In this case, as the real data of the OwnerRRef has |
181 | // not been created yet, there is nothing to be deleted either. It is the same |
// as if Z does not exist at all. Hence, it's still OK.
183 | // |
184 | // See #26759 for more details and discussions. |
185 | // |
186 | // TODO: make RRef an IValue, and edit createStackForSchema accordingly |
187 | // TODO: make RRef system messages idempotent and retry on failures. |
188 | // |
189 | // ``RRef`` is the base type for both ``UserRRef`` and ``OwnerRRef``. |
190 | // Each ``RRef`` has a globally unique ``RRefId``. |
191 | class TORCH_API RRef : public RRefInterface { |
192 | public: |
193 | // RRef is made NOT copyable NOT movable to prevent messing up reference |
194 | // counting. |
195 | explicit RRef(const RRef& other) = delete; |
196 | explicit RRef(RRef&& other) = delete; |
197 | RRef& operator=(RRef&& other) = delete; |
198 | |
199 | ~RRef() override = default; |
200 | |
201 | // returns the worker id of the owner |
202 | inline worker_id_t owner() const override { |
203 | return ownerId_; |
204 | } |
205 | |
206 | // returns the worker name of the owner |
207 | inline std::string ownerName() const override { |
208 | return RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_).name_; |
209 | } |
210 | |
211 | // returns the worker info of the owner |
212 | inline WorkerInfo ownerWorkerInfo() const { |
213 | return RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_); |
214 | } |
215 | |
216 | // Returns the globally unique RRefId of this RRef |
217 | inline const RRefId& rrefId() const { |
218 | return rrefId_; |
219 | } |
220 | |
221 | inline bool isPyObj() const { |
222 | return type_ == PyObjectType::get(); |
223 | } |
224 | inline const TypePtr type() const override { |
225 | return type_; |
226 | } |
227 | |
228 | // Save the future corresponding to the creation of this RRef on a remote |
229 | // node. Note that this is only set when processing requests invoked with |
230 | // rpc.remote. This is only used to get the future corresponding to the rref |
231 | // for profiling use cases. |
232 | inline void registerOwnerCreationFuture(c10::intrusive_ptr<JitFuture> fut) { |
233 | ownerCreationFuture_ = std::move(fut); |
234 | } |
235 | |
236 | // Get the future corresponding to the creation of this rref. |
237 | inline c10::intrusive_ptr<JitFuture> getOwnerCreationFuture() const { |
238 | return ownerCreationFuture_; |
239 | } |
240 | |
241 | // Check if creation of this RRef on owner node has timed out. |
242 | inline bool getTimedOut() const { |
243 | return timedOut_.load(); |
244 | } |
245 | |
246 | // Dispatches an error to the correct handler based on its RPCErrorType. |
247 | void handleError(RPCErrorType errorType, const JitFuture& JitFuture); |
248 | |
249 | // Send delete UserRRef request to Owner, |
250 | // if the request hasn't been sent yet. |
251 | // There are 2 cases to call it, |
252 | // 1, Python GC decides end of UserRRef lifetime, calling destructor. |
253 | // 2, RPC module graceful shutdown calls it on all UserRRefs tracked |
254 | // in the RRefContext. |
255 | virtual void tryDel() {} |
256 | |
257 | protected: |
258 | // Indicates that the creation of this RRef on owner node has timed out. |
259 | inline void setTimedOut() { |
260 | timedOut_ = true; |
261 | } |
262 | friend class RRefContext; |
263 | |
264 | RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type); |
265 | |
266 | virtual RRefForkData fork() const; |
267 | |
268 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
269 | const worker_id_t ownerId_; |
270 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
271 | const RRefId rrefId_; |
272 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
273 | std::atomic<bool> timedOut_{false}; |
274 | |
275 | // type field to denote the type of the element that the RRef is holding |
276 | // it could be any TypePtr that JIT support, including PyObjectType |
277 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
278 | const TypePtr type_; |
279 | // Future corresponding to request to create RRef on remote node. |
280 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
281 | c10::intrusive_ptr<JitFuture> ownerCreationFuture_; |
282 | }; |
283 | |
284 | // ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user |
285 | // also has a globally unique ``ForkId`` to identify this user. ``UserRRef`` |
286 | // never owns the real value, the only way to get the value of the ``RRef`` is |
287 | // to call ``to_here()`` and get a copy.. |
288 | class TORCH_API UserRRef final : public RRef { |
289 | public: |
290 | UserRRef(const UserRRef& other) = delete; |
291 | UserRRef(UserRRef&& other) = delete; |
292 | UserRRef& operator=(const UserRRef& other) = delete; |
293 | UserRRef& operator=(UserRRef&& other) = delete; |
294 | |
295 | UserRRef( |
296 | worker_id_t ownerId, |
297 | const RRefId& rrefId, |
298 | const ForkId& forkId, |
299 | TypePtr type); |
300 | |
301 | inline bool isOwner() const override { |
302 | return false; |
303 | } |
304 | |
305 | inline bool confirmedByOwner() const override { |
306 | return confirmedByOwner_; |
307 | } |
308 | |
309 | // Returns the globally unique ForkId of this RRef |
310 | const ForkId& forkId() const; |
311 | |
312 | // Get of copy of the value from the ``OwnerRRef``. If the value is not ready |
313 | // yet, this call will block. |
314 | IValue toHere( |
315 | const float timeoutSeconds = |
316 | torch::distributed::rpc::kUnsetRpcTimeout) const; |
317 | |
318 | void tryDel() override; |
319 | |
320 | // Will be called when refcount reaches 0. |
321 | // Upon destruction, this ``UserRRef`` will tell the owner to deref. |
322 | void release_resources() override; |
323 | |
324 | // Will be called when both refcount and weakcount reach 0. See |
325 | // https://github.com/pytorch/pytorch/blob/9116f02bebf3a5260feef5732d36c54ecb3b4033/c10/util/intrusive_ptr.h#L204 |
326 | // This is called on destructing the wrapping intrusive_ptr_target instance |
327 | // and it's data members. |
328 | ~UserRRef() override; |
329 | |
330 | private: |
331 | friend class RRefContext; |
332 | |
333 | RRefForkData fork() const override; |
334 | inline void confirm() { |
335 | confirmedByOwner_ = true; |
336 | } |
337 | |
338 | const ForkId forkId_; |
339 | |
340 | // Indicates if this user has sent delete message to it's owner. |
341 | // Note, thread safety is needed because delete message could be sent by |
342 | // either the destructor called by Python garbage collection or RRefContext |
343 | // proactive cleanup on RPC graceful shutdown. |
344 | std::mutex deletedOnOwnerMutex_; |
345 | bool deletedOnOwner_{false}; |
346 | // Indicating whether this UserRRef has been confirmed by its owner. |
347 | std::atomic<bool> confirmedByOwner_; |
348 | }; |
349 | |
350 | // Keep the template only on the derived class because ``RRefContext`` needs to |
351 | // erase the type on ``RRef`` and keep them in one map. |
352 | class TORCH_API OwnerRRef final : public RRef { |
353 | public: |
354 | OwnerRRef(const OwnerRRef& other) = delete; |
355 | OwnerRRef(OwnerRRef&& other) = delete; |
356 | OwnerRRef& operator=(const OwnerRRef& other) = delete; |
357 | OwnerRRef& operator=(OwnerRRef&& other) = delete; |
358 | |
359 | OwnerRRef( |
360 | worker_id_t ownerId, |
361 | const RRefId& rrefId, |
362 | TypePtr type, |
363 | std::vector<c10::Device> devices); |
364 | |
365 | OwnerRRef( |
366 | worker_id_t ownerId, |
367 | const RRefId& rrefId, |
368 | TypePtr type, |
369 | c10::optional<IValue> value, |
370 | std::vector<c10::Device> devices); |
371 | |
372 | inline bool isOwner() const override { |
373 | return true; |
374 | } |
375 | |
376 | // OwnerRRef is always confirmed, while UserRRef is only confirmed when the |
377 | // owner knows about it. |
378 | inline bool confirmedByOwner() const override { |
379 | return true; |
380 | } |
381 | |
382 | // Get a constant reference of the real value. This method will block if the |
383 | // value is not ready. This method does not need GIL as it does not create |
384 | // any new py::object. It will throw if there is an error. |
385 | const IValue& getValue() const; |
386 | |
387 | // Set the value of this ``OwnerRRef``. This method does not need GIL as it |
388 | // does not create any new py::object. |
389 | void setValue(IValue&& value); |
390 | // Sets the value of this ``OwnerRRef`` to contain an exception. |
391 | void setError(std::exception_ptr eptr); |
392 | |
393 | // Has a value or error been set? |
394 | bool hasValue() const; |
395 | // Gets a future that is satisfied when the value or error is set. |
396 | c10::intrusive_ptr<JitFuture> getFuture(); |
397 | |
398 | private: |
399 | friend class RRefContext; |
400 | |
401 | c10::intrusive_ptr<JitFuture> future_; |
402 | }; |
403 | |
404 | TORCH_API std::ostream& operator<<(std::ostream& os, const RRef& rref); |
405 | |
406 | // Helper function that casts from c10::RRefInterface to OwnerRRef |
407 | inline TORCH_API c10::intrusive_ptr<OwnerRRef> fromRRefInterface( |
408 | const c10::intrusive_ptr<c10::RRefInterface>& rrefInterface) { |
409 | return c10::static_intrusive_pointer_cast<OwnerRRef>(rrefInterface); |
410 | } |
411 | |
412 | // Helper function that casts from OwnerRRef to c10::RRefInterface |
413 | inline TORCH_API c10::intrusive_ptr<c10::RRefInterface> fromOwnerRRef( |
414 | const c10::intrusive_ptr<RRef>& ownerRRef) { |
415 | return c10::static_intrusive_pointer_cast<c10::RRefInterface>(ownerRRef); |
416 | } |
417 | |
418 | } // namespace rpc |
419 | } // namespace distributed |
420 | } // namespace torch |
421 | |