1 | #pragma once |
2 | |
3 | #include <c10/util/Optional.h> |
4 | #include <torch/csrc/distributed/rpc/message.h> |
5 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
6 | #include <torch/csrc/distributed/rpc/rref_impl.h> |
7 | #include <torch/csrc/distributed/rpc/types.h> |
8 | #include <torch/csrc/distributed/rpc/utils.h> |
9 | |
10 | #include <atomic> |
11 | |
12 | namespace torch { |
13 | namespace distributed { |
14 | namespace rpc { |
15 | |
16 | namespace callback { |
17 | // It's the callback for RemoteCall. |
18 | void TORCH_API |
19 | confirmPendingUser(const JitFuture& jitFuture, const ForkId& expectedForkId); |
20 | |
21 | // It's the callback for finishing creating owner rref, it returned deletedRRef, |
22 | // so that the deletedRRef can be handled under GIL in python_functions.cpp if |
23 | // deletedRRef contains python object. |
24 | c10::intrusive_ptr<RRef> TORCH_API |
25 | finishCreatingOwnerRRef(const JitFuture& jitFuture, const RRefId& rrefId); |
26 | } // namespace callback |
27 | |
28 | // Manages RRef lifetime and keeps track of RRef forks. |
29 | class TORCH_API RRefContext { |
30 | public: |
31 | static RRefContext& getInstance(); |
32 | // NB: This method must be called before destructing RRefContext singleton. |
33 | // Similar to delForkOfOwner, this method returns a vector of OwnerRRefs that |
34 | // hold py::object. The call-site is also responsible for resetting those |
35 | // shared_ptr objects with a GIL. See comments at delForkOfOwner() for more |
36 | // details. |
37 | static std::vector<c10::intrusive_ptr<RRef>> destroyInstance( |
38 | bool ignoreRRefLeak = true); |
39 | |
40 | static void handleException(const JitFuture& jitFuture); |
41 | |
42 | // handle exception without throw ::c10::Error again |
43 | static void handleExceptionSilent(const JitFuture& jitFuture); |
44 | |
45 | RRefContext(const RRefContext&) = delete; |
46 | RRefContext(RRefContext&& other) = delete; |
47 | void operator=(const RRefContext&) = delete; |
48 | RRefContext& operator=(RRefContext&& other) = delete; |
49 | |
50 | ~RRefContext(); |
51 | |
52 | // get the worker id of the current worker |
53 | inline worker_id_t getWorkerId() const { |
54 | return agent_->getWorkerInfo().id_; |
55 | } |
56 | |
57 | // get the worker name of the current worker |
58 | inline const std::string& getWorkerName() const { |
59 | return agent_->getWorkerInfo().name_; |
60 | } |
61 | |
62 | // generate a globally unique ID |
63 | inline GloballyUniqueId genGloballyUniqueId() { |
64 | return GloballyUniqueId(getWorkerId(), nextLocalId_++); |
65 | } |
66 | |
67 | inline const std::shared_ptr<RpcAgent>& agent() const { |
68 | return agent_; |
69 | } |
70 | |
71 | // create a ``UserRRef`` owned by the worker ``ownerId`` |
72 | c10::intrusive_ptr<UserRRef> createUserRRef( |
73 | worker_id_t ownerId, |
74 | const TypePtr& type); |
75 | |
76 | // Convert an RRefForkData into an RRef. This RRef could be user or owner. |
77 | // This RRef could have already existed before, or could be created in this |
78 | // method, we pass type here to validate or help the rref creation. |
79 | c10::intrusive_ptr<RRef> getOrCreateRRef( |
80 | const RRefForkData& rfd, |
81 | const TypePtr& type); |
82 | |
83 | // Get the ``OwnerRRef`` of id ``rrefId``. If it does not exist, create a new |
84 | // one. This function is called in two places: |
85 | // 1. when processing ``rpc.remote()``, i.e., ``SCRIPT_REMOTE_CALL`` |
86 | // ``PYTHON_REMOTE_CALL``. |
87 | // 2. when unpickling ``OwnerRRef``. |
88 | // What's common in these two cases are, 1) the RRefId is already generated |
89 | // 2) the TypePtr is presented. So it can always create the ``OwnerRRef`` if |
90 | // it is not yet available. |
91 | c10::intrusive_ptr<OwnerRRef> getOrCreateOwnerRRef( |
92 | const RRefId& rrefId, |
93 | const TypePtr& type); |
94 | |
95 | // Create an empty owner rref of type. |
96 | // This method is called to first time generate an ``OwnerRRef``, e.g., |
97 | // 1) ``rpc.RRef(obj)`` |
98 | // 2) create the ``OwnerRRef`` on `rpc.remote()` caller side. |
99 | // What's common in these two cases are, 1) the RRefId hasn't been generated |
100 | // 2) the TypePtr is presented. |
101 | c10::intrusive_ptr<OwnerRRef> createOwnerRRef(const TypePtr& type); |
102 | |
103 | // Returns a Future of the OwnerRRef, which will be marked completed when |
104 | // ``OwnerRRef`` is created. This method is used when the TypePtr is not |
105 | // available, e.g., when processing to_here(). The forceCreated flag can be |
106 | // used to ensure that the rref is created on the owner, otherwise throw in |
107 | // cases where the user of this API expects this to return a completed future. |
108 | // Note that the return value is a intrusive_ptr to a c10::ivalue::Future that |
109 | // holds the RRef. |
110 | c10::intrusive_ptr<JitFuture> getOwnerRRef( |
111 | const RRefId& rrefId, |
112 | bool forceCreated = false); |
113 | |
114 | // Adding the RRefId of an OwnerRRef into the forks_ map. This is useful when |
115 | // making a remote call to self, which as for now, still goes through serde |
116 | // and invokes request callback. In this case, the OwnerRRef has already been |
117 | // created on the send side, and we need to pass it to the receive side, |
118 | // instead of creating a new OwnerRRef. This is done by adding the OwnerRRef |
119 | // into owners_. However, that alone is not enough, as it could be deleted |
120 | // when all UserRRef die, which would then remove the OwnerRRef from owners_ |
121 | // and this could happen before the self remote call finishes. To prevent |
122 | // that, this API adds the RRefId as a ForkId, which will then delete the |
123 | // ForkId when the self remote is done. |
124 | void addSelfAsFork(c10::intrusive_ptr<OwnerRRef>& rref); |
125 | |
126 | // Register a fork of the ``OwnerRRef``, and inserts a intrusive_ptr of the |
127 | // ``OwnerRRef`` in a map to keep it alive. |
128 | void addForkOfOwner(const RRefId& rrefId, const ForkId& forkId); |
129 | // Performs the same function as addForkOfOwner but ignores duplicate |
130 | // requests. This idempotent function is used with RREF_FORK_REQUEST calls, |
131 | // whereas all other message types use the non-idempotent variant. |
132 | void addForkOfOwnerIfNotPresent(const RRefId& rrefId, const ForkId& forkId); |
133 | // Delete a fork of the ``OwnerRRef``. NB: this could trigger deletion on the |
134 | // IValue or py::object. For the later, this method will acquire GIL. |
135 | // NB: If this fork deletion triggered deleting OwnerRRef, this method will |
136 | // return a shared_ptr to the OwnerRRef, which is likely to be the last |
137 | // shared_ptr instance for it. Therefore, deleting this shared_ptr<OwnerRRef> |
138 | // will also trigger deleting the object it points to. If OwnerRRef holds a |
139 | // py::object, deleting it require GIL. The call site should guarded it with |
140 | // a GIL and reset the shared_ptr. The GIL-guarded deletion is intentionally |
141 | // left out of this function to avoid creating dependency on pybind. |
142 | c10::intrusive_ptr<RRef> delForkOfOwner( |
143 | const RRefId& rrefId, |
144 | const ForkId& forkId); |
145 | |
146 | // Invoked when pickling an RRef to setup child/fork properly |
147 | RRefForkData prepareChildFork(const c10::intrusive_ptr<RRef>& rref); |
148 | // Invoked when unpickling an RRef to send RREF_FORK_REQUEST to owner and |
149 | // send RREF_CHILD_ACCEPT to the parent. |
150 | // NB: forkId is necessary here as the rref could be an OwnerRRef |
151 | void notifyOwnerAndParentOfFork( |
152 | const ForkId& forkId, |
153 | worker_id_t parent, |
154 | const c10::intrusive_ptr<RRef>& rref); |
155 | |
156 | // When a UserRRef is forked to another worker (user or owner), it is added |
157 | // into pendingChildren_ to be held alive until it receives RREF_CHILD_ACCEPT |
158 | // from the child. |
159 | // NB: This is necessary for both user and owner child. As we do not have FIFO |
160 | // communication between workers, we need this strategy to make sure that all |
161 | // previously submitted rpc/remote calls are acked before sending out the |
162 | // RREF_USER_DELETE message. Otherwise, the OwnerRRef could be deleted too |
163 | // soon. |
164 | void addPendingChild( |
165 | const ForkId& forkId, |
166 | const c10::intrusive_ptr<RRef>& rref); |
167 | void delPendingChild(const ForkId& forkId); |
168 | |
169 | // When a UserRRef is created, it is added into pendingUsers_ to be held alive |
170 | // until it receives RREF_USER_ACCEPT from the owner. |
171 | void addPendingUser( |
172 | const ForkId& forkId, |
173 | const c10::intrusive_ptr<RRef>& rref); |
174 | void delPendingUser(const ForkId& forkId); |
175 | void addConfirmedUser( |
176 | const ForkId& forkId, |
177 | const c10::intrusive_ptr<RRef>& rref); |
178 | |
179 | // Retrieve a pending user given the fork ID. Throws if the user has already |
180 | // been confirmed (i.e. is no longer in the pendingUsers_ map). |
181 | c10::intrusive_ptr<RRef> getPendingUser(const ForkId& forkId); |
182 | |
183 | // Start recroding new pending UserRRefs. All pending UserRRefs introduced |
184 | // after this point will be put into the thread_local userTable_, which will |
185 | // then be consumed and cleared in waitForThreadLocalPendingRRefs(). |
186 | void recordThreadLocalPendingRRefs(); |
187 | // End recording new pending UserRRefs, and clear the thread_local userTable_. |
188 | // Returns a Future which will be marked as completed when all pending |
189 | // UserRRefs in the current userTable_ are confirmed by their owners. The bool |
190 | // value in the Future is unused. |
191 | // This method is useful to make sure RRefs in user function arguments are |
192 | // confirmed before launching user code. |
193 | // NB: Callers of this method does not need to keep the returned Future alive, |
194 | // because this Future is already captured in callbacks of the |
195 | // PendingUserState. If there is no pending UserRRefs, this method returns a |
196 | // completed future. |
197 | c10::intrusive_ptr<JitFuture> waitForThreadLocalPendingRRefs(); |
198 | // Only call this function when there are errors during a recording session, |
199 | // and it is likely that waitForThreadLocalPendingRRefs() cannot be invoked |
200 | // properly. |
201 | // TODO: make this a context guard |
202 | void clearRecordedPendingRRefsOnError(); |
203 | |
204 | void delUser( |
205 | const worker_id_t owner, |
206 | const RRefId& rrefId, |
207 | const ForkId& forkId); |
208 | void delAllUsersAndUnforkedOwners(std::chrono::milliseconds timeoutMillis); |
209 | |
210 | std::unordered_map<std::string, std::string> getDebugInfo(); |
211 | |
212 | private: |
213 | struct PendingUserState { |
214 | PendingUserState(c10::intrusive_ptr<RRef> rref) |
215 | : rref_(std::move(rref)), |
216 | confirmationFuture_(c10::make_intrusive<JitFuture>(BoolType::get())) { |
217 | } |
218 | |
219 | inline void confirm() { |
220 | c10::static_intrusive_pointer_cast<UserRRef>(rref_)->confirm(); |
221 | confirmationFuture_->markCompleted(); |
222 | } |
223 | |
224 | c10::intrusive_ptr<RRef> rref_; |
225 | // Use Future.wait() and Future.markCompleted() to block and unblock user |
226 | // functions. The bool value wrapped by the future_ is not used. |
227 | c10::intrusive_ptr<JitFuture> confirmationFuture_; |
228 | }; |
229 | |
230 | RRefContext(std::shared_ptr<RpcAgent>); |
231 | |
232 | c10::intrusive_ptr<UserRRef> createUserRRef( |
233 | worker_id_t ownerId, |
234 | const RRefId& rrefId, |
235 | const ForkId& forkId, |
236 | const TypePtr& type); |
237 | |
238 | void finishForkRequest(const ForkId& forkId, worker_id_t parent); |
239 | |
240 | // If there is any leak on any RRef, this method will throw an error. |
241 | void checkRRefLeaks(bool ignoreRRefLeak); |
242 | |
243 | static std::atomic<local_id_t> nextLocalId_; |
244 | |
245 | const std::shared_ptr<RpcAgent> agent_; |
246 | mutable std::mutex mutex_; |
247 | // Keep OwnerRRefs alive until there is no living UserRRefs. |
248 | std::unordered_map<RRefId, c10::intrusive_ptr<RRef>, RRefId::Hash> owners_; |
249 | // A map to track OwnerRRefs that are requested but not yet created. This can |
250 | // happen if the to_here() message is processed on the owner before the |
251 | // corresponding creator rpc.remote() message. If this happens, instead of |
252 | // to_here() RPC thread to block waiting for the OwnerRRef creation, the |
253 | // RRefContext returns a Future, so that the RPC request processing logic can |
254 | // attach subsequent code as a callback to that Future. |
255 | // NB: the OwnerRRefs in this map must be cleared when the corresponding |
256 | // OwnerRRef is created. Note that the values in this map are intrusive_ptrs |
257 | // to c10::ivalue::Future that will be marked completed with the owner RRef. |
258 | std::unordered_map<RRefId, c10::intrusive_ptr<JitFuture>, RRefId::Hash> |
259 | pendingOwners_; |
260 | // Tracks known living UserRRefs of an OwnerRRef |
261 | std::unordered_map< |
262 | RRefId, |
263 | std::unordered_set<ForkId, ForkId::Hash>, |
264 | RRefId::Hash> |
265 | forks_; |
266 | |
267 | // This cond var is used by deleteAllUsers(), a event notificaton is sent if |
268 | // number of pending UserRRef or UserRRef children is reduced, or |
269 | // number of owned OwnerRRef is reduced. |
270 | std::condition_variable deleteAllUsersCV_; |
271 | // The follow 3 maps keep UserRRefs alive by holding a intrusive_ptr to the |
272 | // RRef instances. A UserRRef must be added into this map if any of the |
273 | // following two conditions is true: |
274 | // |
275 | // (1) A UserRRef has not been accepted by owner yet. |
276 | // |
277 | // It can be used or shared, but cannot be deleted, and hence kept alive |
278 | // in this map. A message of type RREF_USER_ACCEPT will move the |
279 | // corresponding RRef from pendingUsers_ map to confirmedUsers_ map. |
280 | std::unordered_map<ForkId, std::shared_ptr<PendingUserState>, ForkId::Hash> |
281 | pendingUsers_; |
282 | // UserRRefs are added into this map when it is confirmed by the owner. |
283 | // When destroying RRefContext this map helps to find local UserRRefs |
284 | // and send delete messages if they are still not deleted by Python |
285 | // garbage collection. |
286 | std::unordered_map<ForkId, c10::weak_intrusive_ptr<RRef>, ForkId::Hash> |
287 | confirmedUsers_; |
288 | |
289 | // (2) A UserRRef has forked a child UserRRef which has not been accepted by |
290 | // the owner yet. |
291 | // |
292 | // In this case, this UserRRef cannot send out RREF_USER_DELETE message, |
293 | // as it could potentially trigger the OwnerRRef been deleted before the |
294 | // owner learns about the forked child. |
295 | std::unordered_map<ForkId, c10::intrusive_ptr<RRef>, ForkId::Hash> |
296 | pendingChildren_; |
297 | |
298 | // The RRef context performs its operations through async RPC requests, in |
299 | // order to not block the user code. Therefore the RRef context's state may be |
300 | // lagging a bit behind what it is intended to be, while it waits for these |
301 | // requests to complete. To allow syncing when needed, we store the count of |
302 | // these pending requests, so that users can wait for it to reach zero. |
303 | std::atomic<int64_t> numPendingFutures_{0}; |
304 | |
305 | std::mutex destroyedMutex_; |
306 | bool destroyed_{false}; |
307 | |
308 | // Thread local states to keep UserRRefs deserialized from user function |
309 | // arguments. |
310 | static thread_local std::vector<std::shared_ptr<PendingUserState>> userTable_; |
311 | // A flag indicating whether subsequently created UserRRefs should be added to |
312 | // the thread_local userTable_. The flag is set to true before serializing |
313 | // RPC arguments and then set to false before running the corresponding |
314 | // user code. See addPendingUser and delPendingUser for more details. |
315 | // NB: The reason for having this flag is because addPendingUser are called in |
316 | // two cases, and we only want to track the 2nd case. |
317 | // (1) RRef as the return value: when calling rpc.remote, the UserRRef on the |
318 | // caller side is added to the context using addPendingUser. |
319 | // (2) RRef as an argument: When running an RPC using RRefs as arguments, the |
320 | // RRef is forwarded to the callee as new UserRRefs (if the callee is not |
321 | // the owner). In this case, we block running the user function until all |
322 | // UserRRefs are confirmed by the owner. |
323 | // This contract gurantees that no UserRRefs can be used remotely without |
324 | // confirmation. Note that, however, the UserRRef created by rpc.remote can |
325 | // still be passed to local functions as arguments and used there. This is by |
326 | // design, because this feature is especially useful when, say a master node |
327 | // creates multiple UserRRefs in a loop and then shares them with other nodes. |
328 | // Blocking every iteration in the loop until RRefs are confirmed will slow |
329 | // this down. This nuance on UserRRef can be interpreted as we only make |
330 | // exceptions for UserRRef creators. And using the UserRRef on its creator |
331 | // without confirmation is OK, because the creator would either call to_here |
332 | // or forward the UserRRef, and both would then require confirmations from the |
333 | // owner. |
334 | static thread_local bool recording_; |
335 | }; |
336 | |
337 | } // namespace rpc |
338 | } // namespace distributed |
339 | } // namespace torch |
340 | |