1#pragma once
2
3#include <ATen/core/jit_type.h>
4#include <ATen/core/rref_interface.h>
5#include <c10/core/Event.h>
6#include <c10/util/Optional.h>
7#include <torch/csrc/distributed/rpc/message.h>
8#include <torch/csrc/distributed/rpc/rpc_agent.h>
9#include <torch/csrc/distributed/rpc/types.h>
10
11#include <atomic>
12
13namespace torch {
14namespace distributed {
15namespace rpc {
16
17class RRef;
18class RRefContext;
19class UserRRef;
20
// Positions of the individual RRefForkData fields inside its py::tuple form.
constexpr int OWNER_IDX = 0; // index of ownerId in the tuple
constexpr int RREFID_ON_IDX = 1; // index of RRefId.createdOn_ in the tuple
constexpr int RREFID_ID_IDX = 2; // index of RRefId.localId_ in the tuple
constexpr int FORKID_ON_IDX = 3; // index of ForkId.createdOn_ in the tuple
constexpr int FORKID_ID_IDX = 4; // index of ForkId.localId_ in the tuple
constexpr int PARENT_IDX = 5; // index of parent in the tuple
constexpr int TYPE_IDX = 6; // index of typeStr in the tuple

// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 7; // number of RRefForkData fields in py::tuple

// Enforce the NB above at compile time: the tuple size has to be exactly one
// past the largest field index, otherwise a field was added without bumping
// RFD_TUPLE_SIZE (or vice versa).
static_assert(
    RFD_TUPLE_SIZE == TYPE_IDX + 1,
    "RFD_TUPLE_SIZE must be one past the last RRefForkData tuple index");
31
32// Represents fork of an RRef to be sent over the wire.
33struct TORCH_API RRefForkData {
34 const worker_id_t ownerId_;
35 const RRefId rrefId_;
36 const ForkId forkId_;
37 const worker_id_t parent_;
38 const std::string typeStr_;
39
40 RRefForkData(
41 worker_id_t ownerId,
42 const RRefId& rrefId,
43 const ForkId& forkId,
44 worker_id_t parent,
45 std::string typeStr);
46};
47
48// Note [RRef Protocol]
49// ~~~~~~~~~~~~~~~~~~~~~~~~~~
50//
51// [Background]
52//
53// RRef stands for Remote REFerence. Each RRef is owned by a single worker
54// (i.e., owner) and can be used by multiple users. The owner stores the real
55// data referenced by its RRefs. RRef needs to support fast and scalable RPC.
56// Hence, in the design, we avoid using a single global master to keep RRef
57// states, instead owners will keep track of the global reference counts
58// for its RRefs. Every RRef can be uniquely identified by a global RRefId,
59// which is assigned at the time it is first created either on a user or on the
60// owner.
61//
62// On the owner worker, there is only one OwnerRRef instance, which contains the
63// real data, while on user workers, there can be as many UserRRefs as
64// necessary, and UserRRef does not hold the data. All usage on the OwnerRRef
65// should retrieve the unique OwnerRRef instance using the globally unique
// RRefId. A UserRRef will be created when it is used as an argument or return
67// value in dist.rpc or dist.remote call, but RRef forking and reference
68// counting (RC) are completely transparent to applications. Every UserRRef will
69// also have its globally unique ForkId.
70//
71// [Assumptions]
72//
73// 1. Transient Network Failures
74//
75// TODO: current RRef implementation does not tolerate failures
76//
77// The RRef design handles transient network failures by retrying
78// messages. Node crashes or permanent network partition is beyond the scope.
79// When those incidents occur, the application may take down all workers, revert
80// to the previous checkpoint, and resume training.
81//
82// 2. Non-idempotent UDFs
83//
84// We assume UDFs are not idempotent and therefore cannot be retried. However,
85// internal RRef control messages are idempotent and retried upon message
86// failure.
87//
88// TODO: RRef internal messages are not yet idempotent
89//
90// 3. Out of Order Message Delivery
91//
92// We do not assume message delivery order between any pair of nodes, because
93// both sender and receiver are using multiple threads. There is no guarantee on
94// which message will be processed first.
95//
96// [RRef Lifetime]
97//
98// The goal of the protocol is to delete an OwnerRRef at an appropriate time.
99// The right time to delete an OwnerRRef is when there are no living UserRRefs
100// and Python GC also agrees to delete the OwnerRRef instance on the owner. The
101// tricky part is to determine if there are any living UserRRefs.
102//
103// A user can get a UserRRef in three situations:
104//
105// (1). Receiving a UserRRef from the owner.
106// (2). Receiving a UserRRef from another user.
107// (3). Creating a new UserRRef owned by another worker.
108//
109// (1) is the simplest case where the owner initiates the fork, and hence it can
110// easily increment local RC. The only requirement is that any UserRRef must
111// notify the owner before destruction. Hence, we need the first guarantee:
112//
113// G1. The owner will be notified when any UserRRef is deleted.
114//
// As messages might come delayed or out-of-order, we need one more guarantee to
116// make sure the delete message is not sent out too soon. Let us first introduce
117// a new concept. If A sends an RPC to B that involves an RRef, we call the RRef
118// on A the parent RRef and the RRef on B the child RRef.
119//
120// G2. Parent RRef cannot be deleted until the child RRef is confirmed by the
121// owner.
122//
123// Under (1), where the caller is UserRRef and callee is OwnerRRef, it simply
124// means that the user will not send out the delete message until all previous
125// messages are ACKed. Note that ACKed does not mean the owner finishes
126// executing the function, instead, it only means the owner has retrieved its
// local OwnerRRef and is about to pass it to the function, which is sufficient
// to keep the OwnerRRef alive even if the delete message from the user arrives
// at the owner before the function finishes execution.
130//
131// With (2) and (3), it is possible that the owner only partially knows the RRef
// fork graph or does not even know it at all. For example, the RRef could be
133// constructed on a user, and before the owner receives the RPC call, the
134// creator user might have already shared the RRef with other users, and those
135// users could further share the RRef. One invariant is that the fork graph of
136// any RRef is always a tree rooted at the owner, because forking an RRef always
137// creates a new RRef instance, and hence every RRef has a single parent. One
138// nasty detail is that when an RRef is created on a user, technically the owner
139// is not its parent but we still consider it that way and it does not break the
140// argument below.
141//
142// The owner's view on any node (fork) in the tree has three stages:
143//
144// 1) unknown -> 2) known -> 3) deleted.
145//
146// The owner's view on the entire tree keeps changing. The owner deletes its
147// OwnerRRef instance when it thinks there are no living UserRRefs, i.e., when
148// OwnerRRef is deleted, all UserRRefs could be either indeed deleted or
149// unknown. The dangerous case is when some forks are unknown and others are
150// deleted.
151//
152// G2 trivially guarantees that no parent UserRRef Y can be deleted before the
153// owner knows all of Y's children UserRRefs.
154//
155// However, it is possible that the child UserRRef Z may be deleted before the
156// owner knows its parent Y. More specifically, this can happen when all of Z's
157// messages are processed by the owner before all messages from Y, including the
// delete message. Nevertheless, this does not cause any problem, because at
// least one of Y's ancestors will be alive, and it will prevent the owner from
// deleting the OwnerRRef. Consider the following example: (NB: this scenario
// will no longer be relevant when we block UDF until all RRefs are confirmed by
// the owner)
163//
164// OwnerRRef -> A -> Y -> Z
165//
166// OwnerRRef forks to A, then A forks to Y, and Y forks to Z. Z can be deleted
167// without OwnerRRef knowing Y. However, the OwnerRRef will at least know A, as
168// the owner directly forks the RRef to A. A won't die before the owner knows Y.
169//
170// Things get a little trickier if the RRef is created on a user:
171//
172// OwnerRRef
173// ^
174// |
175// A -> Y -> Z
176//
177// If Z calls to_here on the UserRRef, the owner at least knows A when Z is
178// deleted, because otherwise to_here wouldn't finish. If Z does not call
179// to_here, it is possible that the owner receives all messages from Z before
180// any message from A and Y. In this case, as the real data of the OwnerRRef has
181// not been created yet, there is nothing to be deleted either. It is the same
// as if Z does not exist at all. Hence, it's still OK.
183//
184// See #26759 for more details and discussions.
185//
186// TODO: make RRef an IValue, and edit createStackForSchema accordingly
187// TODO: make RRef system messages idempotent and retry on failures.
188//
189// ``RRef`` is the base type for both ``UserRRef`` and ``OwnerRRef``.
190// Each ``RRef`` has a globally unique ``RRefId``.
191class TORCH_API RRef : public RRefInterface {
192 public:
193 // RRef is made NOT copyable NOT movable to prevent messing up reference
194 // counting.
195 explicit RRef(const RRef& other) = delete;
196 explicit RRef(RRef&& other) = delete;
197 RRef& operator=(RRef&& other) = delete;
198
199 ~RRef() override = default;
200
201 // returns the worker id of the owner
202 inline worker_id_t owner() const override {
203 return ownerId_;
204 }
205
206 // returns the worker name of the owner
207 inline std::string ownerName() const override {
208 return RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_).name_;
209 }
210
211 // returns the worker info of the owner
212 inline WorkerInfo ownerWorkerInfo() const {
213 return RpcAgent::getCurrentRpcAgent()->getWorkerInfo(ownerId_);
214 }
215
216 // Returns the globally unique RRefId of this RRef
217 inline const RRefId& rrefId() const {
218 return rrefId_;
219 }
220
221 inline bool isPyObj() const {
222 return type_ == PyObjectType::get();
223 }
224 inline const TypePtr type() const override {
225 return type_;
226 }
227
228 // Save the future corresponding to the creation of this RRef on a remote
229 // node. Note that this is only set when processing requests invoked with
230 // rpc.remote. This is only used to get the future corresponding to the rref
231 // for profiling use cases.
232 inline void registerOwnerCreationFuture(c10::intrusive_ptr<JitFuture> fut) {
233 ownerCreationFuture_ = std::move(fut);
234 }
235
236 // Get the future corresponding to the creation of this rref.
237 inline c10::intrusive_ptr<JitFuture> getOwnerCreationFuture() const {
238 return ownerCreationFuture_;
239 }
240
241 // Check if creation of this RRef on owner node has timed out.
242 inline bool getTimedOut() const {
243 return timedOut_.load();
244 }
245
246 // Dispatches an error to the correct handler based on its RPCErrorType.
247 void handleError(RPCErrorType errorType, const JitFuture& JitFuture);
248
249 // Send delete UserRRef request to Owner,
250 // if the request hasn't been sent yet.
251 // There are 2 cases to call it,
252 // 1, Python GC decides end of UserRRef lifetime, calling destructor.
253 // 2, RPC module graceful shutdown calls it on all UserRRefs tracked
254 // in the RRefContext.
255 virtual void tryDel() {}
256
257 protected:
258 // Indicates that the creation of this RRef on owner node has timed out.
259 inline void setTimedOut() {
260 timedOut_ = true;
261 }
262 friend class RRefContext;
263
264 RRef(worker_id_t ownerId, const RRefId& rrefId, TypePtr type);
265
266 virtual RRefForkData fork() const;
267
268 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
269 const worker_id_t ownerId_;
270 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
271 const RRefId rrefId_;
272 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
273 std::atomic<bool> timedOut_{false};
274
275 // type field to denote the type of the element that the RRef is holding
276 // it could be any TypePtr that JIT support, including PyObjectType
277 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
278 const TypePtr type_;
279 // Future corresponding to request to create RRef on remote node.
280 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
281 c10::intrusive_ptr<JitFuture> ownerCreationFuture_;
282};
283
284// ``UserRRef`` represents a user of an RRef. Besides the ``RRefId``, each user
285// also has a globally unique ``ForkId`` to identify this user. ``UserRRef``
286// never owns the real value, the only way to get the value of the ``RRef`` is
// to call ``to_here()`` and get a copy.
288class TORCH_API UserRRef final : public RRef {
289 public:
290 UserRRef(const UserRRef& other) = delete;
291 UserRRef(UserRRef&& other) = delete;
292 UserRRef& operator=(const UserRRef& other) = delete;
293 UserRRef& operator=(UserRRef&& other) = delete;
294
295 UserRRef(
296 worker_id_t ownerId,
297 const RRefId& rrefId,
298 const ForkId& forkId,
299 TypePtr type);
300
301 inline bool isOwner() const override {
302 return false;
303 }
304
305 inline bool confirmedByOwner() const override {
306 return confirmedByOwner_;
307 }
308
309 // Returns the globally unique ForkId of this RRef
310 const ForkId& forkId() const;
311
312 // Get of copy of the value from the ``OwnerRRef``. If the value is not ready
313 // yet, this call will block.
314 IValue toHere(
315 const float timeoutSeconds =
316 torch::distributed::rpc::kUnsetRpcTimeout) const;
317
318 void tryDel() override;
319
320 // Will be called when refcount reaches 0.
321 // Upon destruction, this ``UserRRef`` will tell the owner to deref.
322 void release_resources() override;
323
324 // Will be called when both refcount and weakcount reach 0. See
325 // https://github.com/pytorch/pytorch/blob/9116f02bebf3a5260feef5732d36c54ecb3b4033/c10/util/intrusive_ptr.h#L204
326 // This is called on destructing the wrapping intrusive_ptr_target instance
327 // and it's data members.
328 ~UserRRef() override;
329
330 private:
331 friend class RRefContext;
332
333 RRefForkData fork() const override;
334 inline void confirm() {
335 confirmedByOwner_ = true;
336 }
337
338 const ForkId forkId_;
339
340 // Indicates if this user has sent delete message to it's owner.
341 // Note, thread safety is needed because delete message could be sent by
342 // either the destructor called by Python garbage collection or RRefContext
343 // proactive cleanup on RPC graceful shutdown.
344 std::mutex deletedOnOwnerMutex_;
345 bool deletedOnOwner_{false};
346 // Indicating whether this UserRRef has been confirmed by its owner.
347 std::atomic<bool> confirmedByOwner_;
348};
349
350// Keep the template only on the derived class because ``RRefContext`` needs to
351// erase the type on ``RRef`` and keep them in one map.
352class TORCH_API OwnerRRef final : public RRef {
353 public:
354 OwnerRRef(const OwnerRRef& other) = delete;
355 OwnerRRef(OwnerRRef&& other) = delete;
356 OwnerRRef& operator=(const OwnerRRef& other) = delete;
357 OwnerRRef& operator=(OwnerRRef&& other) = delete;
358
359 OwnerRRef(
360 worker_id_t ownerId,
361 const RRefId& rrefId,
362 TypePtr type,
363 std::vector<c10::Device> devices);
364
365 OwnerRRef(
366 worker_id_t ownerId,
367 const RRefId& rrefId,
368 TypePtr type,
369 c10::optional<IValue> value,
370 std::vector<c10::Device> devices);
371
372 inline bool isOwner() const override {
373 return true;
374 }
375
376 // OwnerRRef is always confirmed, while UserRRef is only confirmed when the
377 // owner knows about it.
378 inline bool confirmedByOwner() const override {
379 return true;
380 }
381
382 // Get a constant reference of the real value. This method will block if the
383 // value is not ready. This method does not need GIL as it does not create
384 // any new py::object. It will throw if there is an error.
385 const IValue& getValue() const;
386
387 // Set the value of this ``OwnerRRef``. This method does not need GIL as it
388 // does not create any new py::object.
389 void setValue(IValue&& value);
390 // Sets the value of this ``OwnerRRef`` to contain an exception.
391 void setError(std::exception_ptr eptr);
392
393 // Has a value or error been set?
394 bool hasValue() const;
395 // Gets a future that is satisfied when the value or error is set.
396 c10::intrusive_ptr<JitFuture> getFuture();
397
398 private:
399 friend class RRefContext;
400
401 c10::intrusive_ptr<JitFuture> future_;
402};
403
404TORCH_API std::ostream& operator<<(std::ostream& os, const RRef& rref);
405
406// Helper function that casts from c10::RRefInterface to OwnerRRef
407inline TORCH_API c10::intrusive_ptr<OwnerRRef> fromRRefInterface(
408 const c10::intrusive_ptr<c10::RRefInterface>& rrefInterface) {
409 return c10::static_intrusive_pointer_cast<OwnerRRef>(rrefInterface);
410}
411
412// Helper function that casts from OwnerRRef to c10::RRefInterface
413inline TORCH_API c10::intrusive_ptr<c10::RRefInterface> fromOwnerRRef(
414 const c10::intrusive_ptr<RRef>& ownerRRef) {
415 return c10::static_intrusive_pointer_cast<c10::RRefInterface>(ownerRRef);
416}
417
418} // namespace rpc
419} // namespace distributed
420} // namespace torch
421