1#pragma once
2
3#include <c10/util/Optional.h>
4#include <torch/csrc/distributed/rpc/message.h>
5#include <torch/csrc/distributed/rpc/rpc_agent.h>
6#include <torch/csrc/distributed/rpc/rref_impl.h>
7#include <torch/csrc/distributed/rpc/types.h>
8#include <torch/csrc/distributed/rpc/utils.h>
9
10#include <atomic>
11
12namespace torch {
13namespace distributed {
14namespace rpc {
15
16namespace callback {
17// It's the callback for RemoteCall.
18void TORCH_API
19confirmPendingUser(const JitFuture& jitFuture, const ForkId& expectedForkId);
20
21// It's the callback for finishing creating owner rref, it returned deletedRRef,
22// so that the deletedRRef can be handled under GIL in python_functions.cpp if
23// deletedRRef contains python object.
24c10::intrusive_ptr<RRef> TORCH_API
25finishCreatingOwnerRRef(const JitFuture& jitFuture, const RRefId& rrefId);
26} // namespace callback
27
28// Manages RRef lifetime and keeps track of RRef forks.
29class TORCH_API RRefContext {
30 public:
31 static RRefContext& getInstance();
32 // NB: This method must be called before destructing RRefContext singleton.
33 // Similar to delForkOfOwner, this method returns a vector of OwnerRRefs that
34 // hold py::object. The call-site is also responsible for resetting those
35 // shared_ptr objects with a GIL. See comments at delForkOfOwner() for more
36 // details.
37 static std::vector<c10::intrusive_ptr<RRef>> destroyInstance(
38 bool ignoreRRefLeak = true);
39
40 static void handleException(const JitFuture& jitFuture);
41
42 // handle exception without throw ::c10::Error again
43 static void handleExceptionSilent(const JitFuture& jitFuture);
44
45 RRefContext(const RRefContext&) = delete;
46 RRefContext(RRefContext&& other) = delete;
47 void operator=(const RRefContext&) = delete;
48 RRefContext& operator=(RRefContext&& other) = delete;
49
50 ~RRefContext();
51
52 // get the worker id of the current worker
53 inline worker_id_t getWorkerId() const {
54 return agent_->getWorkerInfo().id_;
55 }
56
57 // get the worker name of the current worker
58 inline const std::string& getWorkerName() const {
59 return agent_->getWorkerInfo().name_;
60 }
61
62 // generate a globally unique ID
63 inline GloballyUniqueId genGloballyUniqueId() {
64 return GloballyUniqueId(getWorkerId(), nextLocalId_++);
65 }
66
67 inline const std::shared_ptr<RpcAgent>& agent() const {
68 return agent_;
69 }
70
71 // create a ``UserRRef`` owned by the worker ``ownerId``
72 c10::intrusive_ptr<UserRRef> createUserRRef(
73 worker_id_t ownerId,
74 const TypePtr& type);
75
76 // Convert an RRefForkData into an RRef. This RRef could be user or owner.
77 // This RRef could have already existed before, or could be created in this
78 // method, we pass type here to validate or help the rref creation.
79 c10::intrusive_ptr<RRef> getOrCreateRRef(
80 const RRefForkData& rfd,
81 const TypePtr& type);
82
83 // Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new
84 // one. This function is called in two places:
85 // 1. when processing ``rpc.remote()``, i.e., ``SCRIPT_REMOTE_CALL``
86 // ``PYTHON_REMOTE_CALL``.
87 // 2. when unpickling ``OwnerRRef``.
88 // What's common in these two cases are, 1) the RRefId is already generated
89 // 2) the TypePtr is presented. So it can always create the ``OwnerRRef`` if
90 // it is not yet available.
91 c10::intrusive_ptr<OwnerRRef> getOrCreateOwnerRRef(
92 const RRefId& rrefId,
93 const TypePtr& type);
94
95 // Create an empty owner rref of type.
96 // This method is called to first time generate an ``OwnerRRef``, e.g.,
97 // 1) ``rpc.RRef(obj)``
98 // 2) create the ``OwnerRRef`` on `rpc.remote()` caller side.
99 // What's common in these two cases are, 1) the RRefId hasn't been generated
100 // 2) the TypePtr is presented.
101 c10::intrusive_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type);
102
103 // Returns a Future of the OwnerRRef, which will be marked completed when
104 // ``OwnerRRef`` is created. This method is used when the TypePtr is not
105 // available, e.g., when processing to_here(). The forceCreated flag can be
106 // used to ensure that the rref is created on the owner, otherwise throw in
107 // cases where the user of this API expects this to return a completed future.
108 // Note that the return value is a intrusive_ptr to a c10::ivalue::Future that
109 // holds the RRef.
110 c10::intrusive_ptr<JitFuture> getOwnerRRef(
111 const RRefId& rrefId,
112 bool forceCreated = false);
113
114 // Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when
115 // making a remote call to self, which as for now, still goes through serde
116 // and invokes request callback. In this case, the OwnerRRef has already been
117 // created on the send side, and we need to pass it to the receive side,
118 // instead of creating a new OwnerRRef. This is done by adding the OwnerRRef
119 // into owners_. However, that alone is not enough, as it could be deleted
120 // when all UserRRef die, which would then remove the OwnerRRef from owners_
121 // and this could happen before the self remote call finishes. To prevent
122 // that, this API adds the RRefId as a ForkId, which will then delete the
123 // ForkId when the self remote is done.
124 void addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref);
125
126 // Register a fork of the ``OwnerRRef``, and inserts a intrusive_ptr of the
127 // ``OwnerRRef`` in a map to keep it alive.
128 void addForkOfOwner(const RRefId& rrefId, const ForkId& forkId);
129 // Performs the same function as addForkOfOwner but ignores duplicate
130 // requests. This idempotent function is used with RREF_FORK_REQUEST calls,
131 // whereas all other message types use the non-idempotent variant.
132 void addForkOfOwnerIfNotPresent(const RRefId& rrefId, const ForkId& forkId);
133 // Delete a fork of the ``OwnerRRef``. NB: this could trigger deletion on the
134 // IValue or py::object. For the later, this method will acquire GIL.
135 // NB: If this fork deletion triggered deleting OwnerRRef, this method will
136 // return a shared_ptr to the OwnerRRef, which is likely to be the last
137 // shared_ptr instance for it. Therefore, deleting this shared_ptr<OwnerRRef>
138 // will also trigger deleting the object it points to. If OwnerRRef holds a
139 // py::object, deleting it require GIL. The call site should guarded it with
140 // a GIL and reset the shared_ptr. The GIL-guarded deletion is intentionally
141 // left out of this function to avoid creating dependency on pybind.
142 c10::intrusive_ptr<RRef> delForkOfOwner(
143 const RRefId& rrefId,
144 const ForkId& forkId);
145
146 // Invoked when pickling an RRef to setup child/fork properly
147 RRefForkData prepareChildFork(const c10::intrusive_ptr<RRef>& rref);
148 // Invoked when unpickling an RRef to send RREF_FORK_REQUEST to owner and
149 // send RREF_CHILD_ACCEPT to the parent.
150 // NB: forkId is necessary here as the rref could be an OwnerRRef
151 void notifyOwnerAndParentOfFork(
152 const ForkId& forkId,
153 worker_id_t parent,
154 const c10::intrusive_ptr<RRef>& rref);
155
156 // When a UserRRef is forked to another worker (user or owner), it is added
157 // into pendingChildren_ to be held alive until it receives RREF_CHILD_ACCEPT
158 // from the child.
159 // NB: This is necessary for both user and owner child. As we do not have FIFO
160 // communication between workers, we need this strategy to make sure that all
161 // previously submitted rpc/remote calls are acked before sending out the
162 // RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too
163 // soon.
164 void addPendingChild(
165 const ForkId& forkId,
166 const c10::intrusive_ptr<RRef>& rref);
167 void delPendingChild(const ForkId& forkId);
168
169 // When a UserRRef is created, it is added into pendingUsers_ to be held alive
170 // until it receives RREF_USER_ACCEPT from the owner.
171 void addPendingUser(
172 const ForkId& forkId,
173 const c10::intrusive_ptr<RRef>& rref);
174 void delPendingUser(const ForkId& forkId);
175 void addConfirmedUser(
176 const ForkId& forkId,
177 const c10::intrusive_ptr<RRef>& rref);
178
179 // Retrieve a pending user given the fork ID. Throws if the user has already
180 // been confirmed (i.e. is no longer in the pendingUsers_ map).
181 c10::intrusive_ptr<RRef> getPendingUser(const ForkId& forkId);
182
183 // Start recroding new pending UserRRefs. All pending UserRRefs introduced
184 // after this point will be put into the thread_local userTable_, which will
185 // then be consumed and cleared in waitForThreadLocalPendingRRefs().
186 void recordThreadLocalPendingRRefs();
187 // End recording new pending UserRRefs, and clear the thread_local userTable_.
188 // Returns a Future which will be marked as completed when all pending
189 // UserRRefs in the current userTable_ are confirmed by their owners. The bool
190 // value in the Future is unused.
191 // This method is useful to make sure RRefs in user function arguments are
192 // confirmed before launching user code.
193 // NB: Callers of this method does not need to keep the returned Future alive,
194 // because this Future is already captured in callbacks of the
195 // PendingUserState. If there is no pending UserRRefs, this method returns a
196 // completed future.
197 c10::intrusive_ptr<JitFuture> waitForThreadLocalPendingRRefs();
198 // Only call this function when there are errors during a recording session,
199 // and it is likely that waitForThreadLocalPendingRRefs() cannot be invoked
200 // properly.
201 // TODO: make this a context guard
202 void clearRecordedPendingRRefsOnError();
203
204 void delUser(
205 const worker_id_t owner,
206 const RRefId& rrefId,
207 const ForkId& forkId);
208 void delAllUsersAndUnforkedOwners(std::chrono::milliseconds timeoutMillis);
209
210 std::unordered_map<std::string, std::string> getDebugInfo();
211
212 private:
213 struct PendingUserState {
214 PendingUserState(c10::intrusive_ptr<RRef> rref)
215 : rref_(std::move(rref)),
216 confirmationFuture_(c10::make_intrusive<JitFuture>(BoolType::get())) {
217 }
218
219 inline void confirm() {
220 c10::static_intrusive_pointer_cast<UserRRef>(rref_)->confirm();
221 confirmationFuture_->markCompleted();
222 }
223
224 c10::intrusive_ptr<RRef> rref_;
225 // Use Future.wait() and Future.markCompleted() to block and unblock user
226 // functions. The bool value wrapped by the future_ is not used.
227 c10::intrusive_ptr<JitFuture> confirmationFuture_;
228 };
229
230 RRefContext(std::shared_ptr<RpcAgent>);
231
232 c10::intrusive_ptr<UserRRef> createUserRRef(
233 worker_id_t ownerId,
234 const RRefId& rrefId,
235 const ForkId& forkId,
236 const TypePtr& type);
237
238 void finishForkRequest(const ForkId& forkId, worker_id_t parent);
239
240 // If there is any leak on any RRef, this method will throw an error.
241 void checkRRefLeaks(bool ignoreRRefLeak);
242
243 static std::atomic<local_id_t> nextLocalId_;
244
245 const std::shared_ptr<RpcAgent> agent_;
246 mutable std::mutex mutex_;
247 // Keep OwnerRRefs alive until there is no living UserRRefs.
248 std::unordered_map<RRefId, c10::intrusive_ptr<RRef>, RRefId::Hash> owners_;
249 // A map to track OwnerRRefs that are requested but not yet created. This can
250 // happen if the to_here() message is processed on the owner before the
251 // corresponding creator rpc.remote() message. If this happens, instead of
252 // to_here() RPC thread to block waiting for the OwnerRRef creation, the
253 // RRefContext returns a Future, so that the RPC request processing logic can
254 // attach subsequent code as a callback to that Future.
255 // NB: the OwnerRRefs in this map must be cleared when the corresponding
256 // OwnerRRef is created. Note that the values in this map are intrusive_ptrs
257 // to c10::ivalue::Future that will be marked completed with the owner RRef.
258 std::unordered_map<RRefId, c10::intrusive_ptr<JitFuture>, RRefId::Hash>
259 pendingOwners_;
260 // Tracks known living UserRRefs of an OwnerRRef
261 std::unordered_map<
262 RRefId,
263 std::unordered_set<ForkId, ForkId::Hash>,
264 RRefId::Hash>
265 forks_;
266
267 // This cond var is used by deleteAllUsers(), a event notificaton is sent if
268 // number of pending UserRRef or UserRRef children is reduced, or
269 // number of owned OwnerRRef is reduced.
270 std::condition_variable deleteAllUsersCV_;
271 // The follow 3 maps keep UserRRefs alive by holding a intrusive_ptr to the
272 // RRef instances. A UserRRef must be added into this map if any of the
273 // following two conditions is true:
274 //
275 // (1) A UserRRef has not been accepted by owner yet.
276 //
277 // It can be used or shared, but cannot be deleted, and hence kept alive
278 // in this map. A message of type RREF_USER_ACCEPT will move the
279 // corresponding RRef from pendingUsers_ map to confirmedUsers_ map.
280 std::unordered_map<ForkId, std::shared_ptr<PendingUserState>, ForkId::Hash>
281 pendingUsers_;
282 // UserRRefs are added into this map when it is confirmed by the owner.
283 // When destroying RRefContext this map helps to find local UserRRefs
284 // and send delete messages if they are still not deleted by Python
285 // garbage collection.
286 std::unordered_map<ForkId, c10::weak_intrusive_ptr<RRef>, ForkId::Hash>
287 confirmedUsers_;
288
289 // (2) A UserRRef has forked a child UserRRef which has not been accepted by
290 // the owner yet.
291 //
292 // In this case, this UserRRef cannot send out RREF_USER_DELETE message,
293 // as it could potentially trigger the OwnerRRef been deleted before the
294 // owner learns about the forked child.
295 std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash>
296 pendingChildren_;
297
298 // The RRef context performs its operations through async RPC requests, in
299 // order to not block the user code. Therefore the RRef context's state may be
300 // lagging a bit behind what it is intended to be, while it waits for these
301 // requests to complete. To allow syncing when needed, we store the count of
302 // these pending requests, so that users can wait for it to reach zero.
303 std::atomic<int64_t> numPendingFutures_{0};
304
305 std::mutex destroyedMutex_;
306 bool destroyed_{false};
307
308 // Thread local states to keep UserRRefs deserialized from user function
309 // arguments.
310 static thread_local std::vector<std::shared_ptr<PendingUserState>> userTable_;
311 // A flag indicating whether subsequently created UserRRefs should be added to
312 // the thread_local userTable_. The flag is set to true before serializing
313 // RPC arguments and then set to false before running the corresponding
314 // user code. See addPendingUser and delPendingUser for more details.
315 // NB: The reason for having this flag is because addPendingUser are called in
316 // two cases, and we only want to track the 2nd case.
317 // (1) RRef as the return value: when calling rpc.remote, the UserRRef on the
318 // caller side is added to the context using addPendingUser.
319 // (2) RRef as an argument: When running an RPC using RRefs as arguments, the
320 // RRef is forwarded to the callee as new UserRRefs (if the callee is not
321 // the owner). In this case, we block running the user function until all
322 // UserRRefs are confirmed by the owner.
323 // This contract gurantees that no UserRRefs can be used remotely without
324 // confirmation. Note that, however, the UserRRef created by rpc.remote can
325 // still be passed to local functions as arguments and used there. This is by
326 // design, because this feature is especially useful when, say a master node
327 // creates multiple UserRRefs in a loop and then shares them with other nodes.
328 // Blocking every iteration in the loop until RRefs are confirmed will slow
329 // this down. This nuance on UserRRef can be interpreted as we only make
330 // exceptions for UserRRef creators. And using the UserRRef on its creator
331 // without confirmation is OK, because the creator would either call to_here
332 // or forward the UserRRef, and both would then require confirmations from the
333 // owner.
334 static thread_local bool recording_;
335};
336
337} // namespace rpc
338} // namespace distributed
339} // namespace torch
340