1 | #include <torch/csrc/distributed/rpc/request_callback_no_python.h> |
2 | |
3 | #include <c10/core/StreamGuard.h> |
4 | #include <torch/csrc/distributed/autograd/context/container.h> |
5 | #include <torch/csrc/distributed/autograd/engine/dist_engine.h> |
6 | #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h> |
7 | #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h> |
8 | #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h> |
9 | #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h> |
10 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h> |
11 | #include <torch/csrc/distributed/autograd/utils.h> |
12 | #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h> |
13 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
14 | #include <torch/csrc/distributed/rpc/rref_context.h> |
15 | #include <torch/csrc/distributed/rpc/rref_proto.h> |
16 | #include <torch/csrc/distributed/rpc/script_resp.h> |
17 | #include <torch/csrc/distributed/rpc/utils.h> |
18 | |
19 | namespace torch { |
20 | namespace distributed { |
21 | namespace rpc { |
22 | |
23 | using namespace torch::distributed::autograd; |
24 | using namespace torch::autograd::profiler; |
25 | |
26 | // When request message has autograd info, processMessage() will set up valid |
27 | // current context id properly. This struct is used to clean up current context |
28 | // id after processMessage() is done. |
29 | struct DistAutogradContextGuard { |
30 | explicit DistAutogradContextGuard(int64_t ctxId) { |
31 | auto& container = DistAutogradContainer::getInstance(); |
32 | prevCtxId_ = container.currentContextId(); |
33 | container.forceCurrentContextId(ctxId); |
34 | } |
35 | ~DistAutogradContextGuard() { |
36 | auto& container = DistAutogradContainer::getInstance(); |
37 | container.forceCurrentContextId(prevCtxId_); |
38 | } |
39 | |
40 | int64_t prevCtxId_; |
41 | }; |
42 | |
43 | std::unique_ptr<RpcCommandBase> RequestCallbackNoPython:: |
44 | deserializePythonRpcCommand( |
45 | std::unique_ptr<RpcCommandBase> rpc, |
46 | const MessageType& messageType) const { |
47 | TORCH_CHECK( |
48 | messageType != MessageType::PYTHON_CALL && |
49 | messageType != MessageType::PYTHON_REMOTE_CALL, |
50 | "Python calls are not supported!" ); |
51 | return rpc; |
52 | } |
53 | |
// Entry point for an incoming request message. Deserializes the command,
// waits for argument RRefs to be confirmed, runs the RPC, and returns a
// future that completes with the response Message (with storages attached).
// Any exception thrown during deserialization is converted into an error
// response future.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
    Message& request,
    std::vector<c10::Stream> streams) const {
  // We need two futures here because it could pause twice when processing a
  // RPC message:
  // 1) waiting for all RRefs in the arguments to become confirmed;
  // 2) waiting for processRpc to finish.
  auto& rrefContext = RRefContext::getInstance();
  try {
    rrefContext.recordThreadLocalPendingRRefs();
    // Deserialize PythonUDF here to trigger RRef unpickling
    std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
        deserializeRequest(request), request.type());
    auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();

    auto retFuture = rrefsReadyFuture->thenAsync(
        [this,
         // std::function must be copyable, hence have to cast the unique_ptr
         // to a shared_ptr here.
         rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
         messageType = request.type(),
         streams = std::move(streams)](JitFuture& /* unused */) mutable {
          // The cost of pre-request check is minimal thanks to
          // std::shared_lock. The cost is in magnitude
          // of 10us.
          auto serverProcessGlobalProfilerStateStackEntryPtr =
              profiler::processglobal::StateStackEntry::current();
          // If server global profiler is enabled, we further pay the
          // cost of thread local profiler state initialization.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Initialize thread-local profiler state from process-global
            // profiler state.
            enableProfilerLegacy(
                serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
                    ->config());
          }

          auto retFuture =
              processRpcWithErrors(*rpc, messageType, std::move(streams));

          // Response message has been sent at this moment, this post-response
          // work doesn't affect RPC trip time.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Restore thread-local profiler state.
            thread_event_lists event_lists = disableProfilerLegacy();
            // Put thread_local event_lists into the process-global profiler
            // state.
            profiler::processglobal::pushResultRecursive(
                serverProcessGlobalProfilerStateStackEntryPtr, event_lists);
          }

          return retFuture;
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    // Stamp the original request id onto the response so the client side can
    // match it against its pending request.
    auto retFutureWithMessageId = retFuture->then(
        [id = request.id()](JitFuture& future) {
          c10::intrusive_ptr<Message> message =
              future.value().toCustomClass<Message>();
          message->setId(id);
          return withStorages(message);
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    return retFutureWithMessageId;
  } catch (std::exception& e) {
    // Deserialization or RRef recording failed: clear the partially-recorded
    // pending RRefs and reply with an exception response.
    rrefContext.clearRecordedPendingRRefsOnError();
    return asFuture(handleError(e, request.type(), request.id()));
  }
}
124 | |
125 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpcWithErrors( |
126 | RpcCommandBase& rpc, |
127 | const MessageType& messageType, |
128 | std::vector<c10::Stream> streams) const { |
129 | try { |
130 | return processRpc(rpc, messageType, std::move(streams)); |
131 | } catch (std::exception& e) { |
132 | // Pass a dummy message ID since it will be overwritten anyways. |
133 | return asFuture(handleError(e, messageType, -1)); |
134 | } |
135 | } |
136 | |
137 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptCall( |
138 | RpcCommandBase& rpc, |
139 | std::vector<c10::Stream> streams) const { |
140 | auto& scriptCall = static_cast<ScriptCall&>(rpc); |
141 | |
142 | TORCH_CHECK( |
143 | scriptCall.hasOp(), "Only supports the case where ScriptCall has an op" ); |
144 | auto future = runJitOperator( |
145 | *scriptCall.op(), scriptCall.stackRef(), std::move(streams)); |
146 | |
147 | return future->then( |
148 | [](JitFuture& future) { |
149 | return withStorages(ScriptResp(future.value()).toMessage()); |
150 | }, |
151 | c10::getCustomClassType<c10::intrusive_ptr<Message>>()); |
152 | } |
153 | |
// Python UDF execution is unavailable in this Python-free callback; always
// throws. The Python-aware subclass overrides this with a real
// implementation.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!" );
}
159 | |
// Remote Python UDF execution is unavailable in this Python-free callback;
// always throws. The Python-aware subclass overrides this.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonRemoteCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!" );
}
165 | |
// Stores the (future) result of a remote call into the owner RRef identified
// by rrefId, and replies with a RemoteRet message once the value (or error)
// has been set on the RRef.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::assignOwnerRRef(
    const RRefId& rrefId,
    const RRefId& forkId,
    c10::intrusive_ptr<JitFuture> valueFuture) const {
  auto& ctx = RRefContext::getInstance();

  c10::intrusive_ptr<OwnerRRef> ownerRRef;
  if (rrefId == forkId) {
    // Creating an owner RRef on self, should already exist in owners map
    ownerRRef =
        fromRRefInterface(ctx.getOwnerRRef(rrefId, /* forceCreated */ true)
                              ->constValue()
                              .toRRef());
  } else {
    ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, valueFuture->elementType());
    // Caller is a user and callee is the owner, add fork
    //
    // NB: rrefId == forkId is true if and only if calling remote to self.
    // In that case both the caller and the callee will access the
    // OwnerRRef. Hence, on the callee side (here), it should not call
    // addForkOfOwner as it is not a fork. To allow callee to distinguish
    // when this request is sent to self, the caller will set forkId using
    // rrefId (OwnerRRef does not have a forkId anyway).
    ctx.addForkOfOwner(rrefId, forkId);
  }

  return valueFuture->then(
      [ownerRRef, rrefId, forkId](JitFuture& future) {
        // Propagate the computed value (or the error) into the RRef so that
        // pending and future fetches observe it.
        if (future.hasError()) {
          ownerRRef->setError(future.exception_ptr());
        } else {
          ownerRRef->setValue(future.value());
        }
        return withStorages(RemoteRet(rrefId, forkId).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
203 | |
204 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptRemoteCall( |
205 | RpcCommandBase& rpc, |
206 | std::vector<c10::Stream> streams) const { |
207 | auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc); |
208 | |
209 | TORCH_CHECK( |
210 | scriptRemoteCall.hasOp(), "ScriptRemoteCall needs to have an op!" ); |
211 | auto future = runJitOperator( |
212 | *scriptRemoteCall.op(), scriptRemoteCall.stackRef(), std::move(streams)); |
213 | |
214 | return assignOwnerRRef( |
215 | scriptRemoteCall.retRRefId(), |
216 | scriptRemoteCall.retForkId(), |
217 | std::move(future)); |
218 | } |
219 | |
220 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::retrieveOwnerRRef( |
221 | const RRefId& rrefId) const { |
222 | auto& ctx = RRefContext::getInstance(); |
223 | |
224 | auto rrefFuture = ctx.getOwnerRRef(rrefId); |
225 | |
226 | at::TypePtr type = rrefFuture->elementType(); |
227 | TORCH_INTERNAL_ASSERT(type->kind() == at::RRefType::Kind); |
228 | return rrefFuture->thenAsync( |
229 | [](JitFuture& rrefFuture) { |
230 | c10::intrusive_ptr<OwnerRRef> rref = |
231 | fromRRefInterface(rrefFuture.value().toRRef()); |
232 | return rref->getFuture(); |
233 | }, |
234 | type->cast<at::RRefType>()->getElementType()); |
235 | } |
236 | |
237 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython:: |
238 | processScriptRRefFetchCall(RpcCommandBase& rpc) const { |
239 | auto& srf = static_cast<ScriptRRefFetchCall&>(rpc); |
240 | |
241 | auto future = retrieveOwnerRRef(srf.rrefId()); |
242 | |
243 | return future->then( |
244 | [](JitFuture& future) { |
245 | return withStorages(ScriptRRefFetchRet({future.value()}).toMessage()); |
246 | }, |
247 | c10::getCustomClassType<c10::intrusive_ptr<Message>>()); |
248 | } |
249 | |
// Fetching a Python-object RRef requires Python support; always throws here
// and is overridden by the Python-aware subclass.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processPythonRRefFetchCall(RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!" );
}
254 | |
255 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefUserDelete( |
256 | RpcCommandBase& rpc) const { |
257 | auto& rud = static_cast<RRefUserDelete&>(rpc); |
258 | auto& ctx = RRefContext::getInstance(); |
259 | auto deletedRRef = ctx.delForkOfOwner(rud.rrefId(), rud.forkId()); |
260 | handleRRefDelete(deletedRRef); |
261 | return asFuture(RRefAck().toMessage()); |
262 | } |
263 | |
264 | void RequestCallbackNoPython::handleRRefDelete( |
265 | c10::intrusive_ptr<RRef>& rref) const { |
266 | TORCH_CHECK(!rref->isPyObj(), "RRefs with python objects not supported!" ); |
267 | } |
268 | |
269 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefChildAccept( |
270 | RpcCommandBase& rpc) const { |
271 | auto& rca = static_cast<RRefChildAccept&>(rpc); |
272 | auto& ctx = RRefContext::getInstance(); |
273 | ctx.delPendingChild(rca.forkId()); |
274 | return asFuture(RRefAck().toMessage()); |
275 | } |
276 | |
277 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefForkRequest( |
278 | RpcCommandBase& rpc) const { |
279 | auto& rfr = static_cast<RRefForkRequest&>(rpc); |
280 | auto& ctx = RRefContext::getInstance(); |
281 | ctx.addForkOfOwnerIfNotPresent(rfr.rrefId(), rfr.forkId()); |
282 | return asFuture(RRefAck().toMessage()); |
283 | } |
284 | |
// Handles a FORWARD_AUTOGRAD_REQ: attaches a 'recv' autograd function for the
// incoming tensors, runs the wrapped RPC under the client's autograd context
// id, and wraps the nested response with autograd metadata for the reply.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processForwardAutogradReq(
        RpcCommandBase& rpc,
        std::vector<c10::Stream> streams) const {
  auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

  // Need to reverse the device map for the backward pass of distributed
  // autograd.
  DeviceMap reverseDeviceMap;
  for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
    reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
  }

  // Attach 'recv' autograd function.
  auto autogradContext = addRecvRpcBackward(
      rpcWithAutograd.autogradMetadata(),
      rpcWithAutograd.tensors(),
      rpcWithAutograd.fromWorkerId(),
      reverseDeviceMap);
  // For this recv thread on server side, before processRpc(),
  // set current_context_id_ to be context_id passed from client.
  // In this way, if there is nested rpc call in python rpc call, original
  // context_id from client can be passed in the chain calls.
  TORCH_INTERNAL_ASSERT(
      autogradContext != nullptr,
      "autogradContext is nullptr, FORWARD_AUTOGRAD_REQ should always get "
      "or create valid autogradContext in addRecvRpcBackward." );

  DistAutogradContextGuard ctxGuard(autogradContext->contextId());

  // Process the original RPC.
  auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
  // Kick off processing for the nested RPC command.
  // wrappedRpcResponseFuture will be a Future<T> to the result.
  auto wrappedRpcResponseFuture = processRpc(
      rpcWithAutograd.wrappedRpc(), wrappedMessageType, std::move(streams));

  auto fromWorkerId = rpcWithAutograd.fromWorkerId();
  // The original future needs to be marked as completed when the wrapped
  // one completes, with the autograd context information wrapped.
  auto responseFuture = wrappedRpcResponseFuture->then(
      [fromWorkerId, ctxId = autogradContext->contextId()](
          JitFuture& wrappedRpcResponseFuture) {
        // As this callback can be invoked by a different thread, we have to
        // make sure that the thread_local states in the previous thread is
        // correctly propagated.
        // NB: The execution of TorchScript functions can also run on a
        // different thread, which is addressed by
        // https://github.com/pytorch/pytorch/pull/36395
        // NB: when adding async UDF support, we should also propagate
        // thread_local states there.
        // TODO: Land on a general solution for RPC ThreadLocalState. See
        // https://github.com/pytorch/pytorch/issues/38510
        DistAutogradContextGuard cbCtxGuard(ctxId);

        if (wrappedRpcResponseFuture.hasError()) {
          // Propagate error to responseFuture if we had one.
          std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
        } else {
          // Wrap the nested response with autograd metadata so the client
          // can attach its matching 'recv' function on the way back.
          auto msg = getMessageWithAutograd(
              fromWorkerId,
              wrappedRpcResponseFuture.value().toCustomClass<Message>(),
              MessageType::FORWARD_AUTOGRAD_RESP);
          return withStorages(std::move(msg));
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());

  return responseFuture;
}
355 | |
// Handles a BACKWARD_AUTOGRAD_REQ: looks up the autograd context and the
// matching 'send' function, attaches the incoming gradients to it, and runs
// the distributed autograd engine. The response is sent once the engine's
// (possibly nested) RPCs complete.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processBackwardAutogradReq(
        RpcCommandBase& rpc,
        std::vector<c10::Stream> streams) const {
  // Make the provided streams current while setting up the backward call.
  c10::MultiStreamGuard guard(streams);
  auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
  const auto& autogradMetadata = gradientsCall.getAutogradMetadata();

  // Retrieve the appropriate autograd context.
  auto autogradContext = DistAutogradContainer::getInstance().retrieveContext(
      autogradMetadata.autogradContextId);

  // Lookup the appropriate 'send' function to enqueue.
  std::shared_ptr<SendRpcBackward> sendFunction =
      autogradContext->retrieveSendFunction(autogradMetadata.autogradMessageId);

  // Attach the gradients to the send function.
  sendFunction->setGrads(gradientsCall.getGrads());

  // Now execute the autograd graph using the "distributed engine."
  auto execFuture = DistEngine::getInstance().executeSendFunctionAsync(
      autogradContext, sendFunction, gradientsCall.retainGraph());

  // Our response is satisfied when the rpcs come back.
  return execFuture->then(
      [](JitFuture& execFuture) {
        if (execFuture.hasError()) {
          std::rethrow_exception(execFuture.exception_ptr());
        } else {
          return withStorages(PropagateGradientsResp().toMessage());
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
390 | |
391 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython:: |
392 | processCleanupAutogradContextReq(RpcCommandBase& rpc) const { |
393 | auto& cleanupContextReq = static_cast<CleanupAutogradContextReq&>(rpc); |
394 | auto cleanupContextId = cleanupContextReq.getContextId(); |
395 | // release the context if it still exists on this thread. We need to |
396 | // check if it exists since it may have been deleted by an in-flight |
397 | // RPC. This can create nested RPCs if there are other nodes that get |
398 | // notified to clean up their context. |
399 | DistAutogradContainer::getInstance().releaseContextIfPresent( |
400 | cleanupContextId); |
401 | return asFuture(CleanupAutogradContextResp().toMessage()); |
402 | } |
403 | |
// Handles a RUN_WITH_PROFILING_REQ: enables the legacy profiler on this
// thread with the sender's config, runs the wrapped RPC, and — once any
// async work completes — consolidates the profiled events into the response.
// Profiler TLS is cleaned up on this thread when the guard's scope exits,
// while event consolidation is deferred to the continuation.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processRunWithProfilingReq(RpcCommandBase& rpc) const {
  auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
  auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType();
  auto profilingConfig = rpcWithProfilingReq.getProfilingConfig();

  // The legacy profiler used here cannot run in KINETO modes; downgrade to
  // CPU-only profiling while keeping the shape/memory reporting options.
  if (profilingConfig.state == ProfilerState::KINETO ||
      profilingConfig.state == ProfilerState::KINETO_GPU_FALLBACK) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);
  }

  // If requested with CUDA from caller but CUDA is not available on this
  // machine, fallback to CPU and log a warning instead of crashing.
  if (profilingConfig.state == ProfilerState::CUDA && !this->cudaAvailable()) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);

    LOG(WARNING) << "Profiler was requested to be enabled with CUDA on this "
                    "node, but CUDA is not available. "
                 << "Falling back to CPU profiling only." ;
  }
  TORCH_INTERNAL_ASSERT(
      profilingConfig.state != ProfilerState::CUDA || this->cudaAvailable(),
      "Profiler state set to CUDA but CUDA not available." );
  const auto profilingKeyId = rpcWithProfilingReq.getProfilingId();
  // Enable the profiler with the config from the sender.
  // When enabling on the main thread, ensure profiler states are cleaned
  // up, but defer consolidation of all profiled events to the continuation
  // below.
  ProfilerDisableOptions requestThreadOptions(
      true /* cleanup TLS state */, false /* consolidate events */);
  {
    TLSLegacyProfilerGuard g(
        profilingConfig, c10::nullopt, requestThreadOptions);
    TORCH_INTERNAL_ASSERT(
        profilerEnabled(), "Expected profiler to be enabled!" );
    // Kick off processing for nested work and get Future<T> result in
    // wrappedRpcResponseFuture
    auto wrappedRpcResponseFuture = processRpc(
        rpcWithProfilingReq.wrappedRpc(),
        wrappedMsgType,
        {}); // TODO: https://github.com/pytorch/pytorch/issues/55757

    // wrapPropagateTLSState makes the continuation observe the same
    // thread-local state (including profiler state) as this thread.
    auto responseFuture = wrappedRpcResponseFuture->then(
        at::wrapPropagateTLSState([profilingKeyId, profilingConfig](
                                      JitFuture& wrappedRpcResponseFuture) {
          std::vector<LegacyEvent> profiledEvents;
          // Defer consolidation of profiler events until async work has
          // completed (such as async UDF)

          TORCH_INTERNAL_ASSERT(
              profilerEnabled(), "Expected profiler to be enabled!" );

          // On continuation thread, don't clean up profiler states, since
          // they will be cleaned up by main thread, and consolidate all
          // events so we obtain asynchronously run events.
          ProfilerDisableOptions opts(false, true);
          auto event_lists = disableProfilerLegacy(opts);
          if (wrappedRpcResponseFuture.hasError()) {
            // Propagate error
            // No need to propagate remote events in the case of an error.
            std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
          } else {
            populateRemoteProfiledEvents(
                profiledEvents, profilingConfig, event_lists);
            auto rpcWithProfilingResp = std::make_unique<RpcWithProfilingResp>(
                MessageType::RUN_WITH_PROFILING_RESP,
                wrappedRpcResponseFuture.value().toCustomClass<Message>(),
                profiledEvents,
                profilingKeyId);
            return withStorages(std::move(*rpcWithProfilingResp).toMessage());
          }
        }),
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    return responseFuture;
    // Exiting the scope will disable the profiler on this thread with the
    // options specified above.
  }
}
489 | |
// RRef backward is only implemented in the Python-aware callback; the base
// implementation always throws. NOTE(review): the error text reuses the
// generic "Python call" message even though this is an RRef backward request.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefBackward(
    RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!" );
}
494 | |
495 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpc( |
496 | RpcCommandBase& rpc, |
497 | const MessageType& messageType, |
498 | std::vector<c10::Stream> streams) const { |
499 | // TODO: RpcCommandBase should have an abstract execute() method that we can |
500 | // call here instead of having another switch statement here. Even better we |
501 | // could have abstract classes RpcRequest and RpcResp which inherit from |
502 | // RpcCommandBase and RpcRequest declares the abstract method execute() that |
503 | // we can call here. RpcResponse could have an abstract method to convert it |
504 | // to a python object. |
505 | switch (messageType) { |
506 | case MessageType::SCRIPT_CALL: { |
507 | return processScriptCall(rpc, std::move(streams)); |
508 | } |
509 | case MessageType::PYTHON_CALL: { |
510 | return processPythonCall(rpc, std::move(streams)); |
511 | } |
512 | case MessageType::SCRIPT_REMOTE_CALL: { |
513 | return processScriptRemoteCall(rpc, std::move(streams)); |
514 | } |
515 | case MessageType::PYTHON_REMOTE_CALL: { |
516 | return processPythonRemoteCall(rpc, std::move(streams)); |
517 | } |
518 | case MessageType::SCRIPT_RREF_FETCH_CALL: { |
519 | return processScriptRRefFetchCall(rpc); |
520 | } |
521 | case MessageType::PYTHON_RREF_FETCH_CALL: { |
522 | return processPythonRRefFetchCall(rpc); |
523 | } |
524 | case MessageType::RREF_USER_DELETE: { |
525 | return processRRefUserDelete(rpc); |
526 | } |
527 | case MessageType::RREF_CHILD_ACCEPT: { |
528 | return processRRefChildAccept(rpc); |
529 | } |
530 | case MessageType::RREF_FORK_REQUEST: { |
531 | return processRRefForkRequest(rpc); |
532 | } |
533 | case MessageType::FORWARD_AUTOGRAD_REQ: { |
534 | return processForwardAutogradReq(rpc, std::move(streams)); |
535 | } |
536 | case MessageType::BACKWARD_AUTOGRAD_REQ: { |
537 | return processBackwardAutogradReq(rpc, std::move(streams)); |
538 | }; |
539 | case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: { |
540 | return processCleanupAutogradContextReq(rpc); |
541 | } |
542 | case MessageType::RUN_WITH_PROFILING_REQ: { |
543 | return processRunWithProfilingReq(rpc); |
544 | } |
545 | case MessageType::RREF_BACKWARD_REQ: { |
546 | return processRRefBackward(rpc); |
547 | } |
548 | default: { |
549 | TORCH_INTERNAL_ASSERT( |
550 | false, "Request type " , messageType, " not supported." ); |
551 | } |
552 | } |
553 | } |
554 | |
555 | c10::intrusive_ptr<Message> RequestCallbackNoPython::handleError( |
556 | const std::exception& e, |
557 | const MessageType messageType, |
558 | int64_t messageId) const { |
559 | LOG(ERROR) << "Received error while processing request type " << messageType |
560 | << ": " << e.what(); |
561 | // Adding node information to the error here since all processed RPC |
562 | // requests should be going through this function. |
563 | std::string errorMsg = c10::str( |
564 | "Error on Node " , |
565 | DistAutogradContainer::getInstance().getWorkerId(), |
566 | ": " , |
567 | e.what()); |
568 | return createExceptionResponse(errorMsg, messageId); |
569 | } |
570 | |
571 | bool RequestCallbackNoPython::cudaAvailable() const { |
572 | #ifdef USE_CUDA |
573 | return true; |
574 | #else |
575 | return false; |
576 | #endif |
577 | } |
578 | |
// Runs a JIT builtin operator synchronously on the given stack with the
// provided streams made current, and wraps the single result (or a thrown
// exception) into an already-completed future.
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::runJitOperator(
    const jit::Operator& op,
    std::vector<at::IValue>& stack,
    std::vector<c10::Stream> streams) const {
  c10::MultiStreamGuard guard(streams);
  try {
    op.getOperation()(stack);
  } catch (const std::exception&) {
    // Surface operator failures through the future instead of throwing.
    return asFuture(std::current_exception());
  }
  TORCH_INTERNAL_ASSERT(
      stack.size() == 1,
      "Return value of a builtin operator or a TorchScript function should be "
      "a single IValue, got a vector of size " ,
      stack.size());
  TypePtr type = stack.front().type();
  return asFuture(std::move(stack.front()), std::move(type));
}
597 | |
598 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture( |
599 | IValue value, |
600 | TypePtr type) const { |
601 | auto future = c10::make_intrusive<JitFuture>( |
602 | std::move(type), RpcAgent::getCurrentRpcAgent()->getDevices()); |
603 | future->markCompleted(std::move(value)); |
604 | return future; |
605 | } |
606 | |
607 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture( |
608 | c10::intrusive_ptr<Message> message) const { |
609 | auto future = c10::make_intrusive<JitFuture>( |
610 | at::getCustomClassType<c10::intrusive_ptr<Message>>(), |
611 | RpcAgent::getCurrentRpcAgent()->getDevices()); |
612 | std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages = |
613 | message->getStorages(); |
614 | future->markCompleted(std::move(message), std::move(storages)); |
615 | return future; |
616 | } |
617 | |
618 | c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture( |
619 | std::exception_ptr err) const { |
620 | auto future = c10::make_intrusive<JitFuture>( |
621 | at::NoneType::get(), RpcAgent::getCurrentRpcAgent()->getDevices()); |
622 | future->setError(err); |
623 | return future; |
624 | } |
625 | |
626 | } // namespace rpc |
627 | } // namespace distributed |
628 | } // namespace torch |
629 | |