#include <torch/csrc/distributed/rpc/request_callback_no_python.h>

#include <c10/core/StreamGuard.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace torch {
namespace distributed {
namespace rpc {

using namespace torch::distributed::autograd;
using namespace torch::autograd::profiler;

// When a request message carries autograd info, processMessage() sets up a
// valid current context id. This struct restores the previous context id
// once processMessage() is done.
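//
// A typical use (see processForwardAutogradReq below) is to construct the
// guard with the context id recovered from the request, process the RPC,
// and let the destructor restore the previous context id:
//
//   DistAutogradContextGuard ctxGuard(autogradContext->contextId());
//   auto future = processRpc(...); // runs under the request's context id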
struct DistAutogradContextGuard {
  explicit DistAutogradContextGuard(int64_t ctxId) {
    auto& container = DistAutogradContainer::getInstance();
    prevCtxId_ = container.currentContextId();
    container.forceCurrentContextId(ctxId);
  }
  ~DistAutogradContextGuard() {
    auto& container = DistAutogradContainer::getInstance();
    container.forceCurrentContextId(prevCtxId_);
  }

  int64_t prevCtxId_;
};

std::unique_ptr<RpcCommandBase> RequestCallbackNoPython::
    deserializePythonRpcCommand(
        std::unique_ptr<RpcCommandBase> rpc,
        const MessageType& messageType) const {
  TORCH_CHECK(
      messageType != MessageType::PYTHON_CALL &&
          messageType != MessageType::PYTHON_REMOTE_CALL,
      "Python calls are not supported!");
  return rpc;
}

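// Entry point for an incoming request: deserialize the message, wait for all
// RRefs in the arguments to be confirmed, run the command, and finally stamp
// the response with the request's message id.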
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
    Message& request,
    std::vector<c10::Stream> streams) const {
  // We need two futures here because processing an RPC message can pause
  // twice:
  // 1) waiting for all RRefs in the arguments to become confirmed;
  // 2) waiting for processRpc to finish.
  auto& rrefContext = RRefContext::getInstance();
  try {
    rrefContext.recordThreadLocalPendingRRefs();
    // Deserialize PythonUDF here to trigger RRef unpickling
    std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
        deserializeRequest(request), request.type());
    auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();

    auto retFuture = rrefsReadyFuture->thenAsync(
        [this,
         // std::function must be copyable, hence we have to cast the
         // unique_ptr to a shared_ptr here.
         rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
         messageType = request.type(),
         streams = std::move(streams)](JitFuture& /* unused */) mutable {
          // The cost of the pre-request check is minimal thanks to
          // std::shared_lock; it is on the order of 10us.
          auto serverProcessGlobalProfilerStateStackEntryPtr =
              profiler::processglobal::StateStackEntry::current();
          // If the server global profiler is enabled, we further pay the
          // cost of thread-local profiler state initialization.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Initialize thread-local profiler state from process-global
            // profiler state.
            enableProfilerLegacy(
                serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
                    ->config());
          }

          auto retFuture =
              processRpcWithErrors(*rpc, messageType, std::move(streams));

          // The response message has been sent at this moment, so this
          // post-response work doesn't affect the RPC trip time.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Restore thread-local profiler state.
            thread_event_lists event_lists = disableProfilerLegacy();
            // Put the thread-local event_lists into the process-global
            // profiler state.
            profiler::processglobal::pushResultRecursive(
                serverProcessGlobalProfilerStateStackEntryPtr, event_lists);
          }

          return retFuture;
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    auto retFutureWithMessageId = retFuture->then(
        [id = request.id()](JitFuture& future) {
          c10::intrusive_ptr<Message> message =
              future.value().toCustomClass<Message>();
          message->setId(id);
          return withStorages(message);
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    return retFutureWithMessageId;
  } catch (std::exception& e) {
    rrefContext.clearRecordedPendingRRefsOnError();
    return asFuture(handleError(e, request.type(), request.id()));
  }
}

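// Thin wrapper around processRpc() that converts any exception thrown while
// processing into a future that holds an error-response message.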
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpcWithErrors(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    std::vector<c10::Stream> streams) const {
  try {
    return processRpc(rpc, messageType, std::move(streams));
  } catch (std::exception& e) {
    // Pass a dummy message ID since it will be overwritten anyway.
    return asFuture(handleError(e, messageType, -1));
  }
}

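// Handles SCRIPT_CALL requests. This Python-free callback only supports
// builtin operators; the operator's result is wrapped into a ScriptResp
// message.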
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> streams) const {
  auto& scriptCall = static_cast<ScriptCall&>(rpc);

  TORCH_CHECK(
      scriptCall.hasOp(), "Only supports the case where ScriptCall has an op");
  auto future = runJitOperator(
      *scriptCall.op(), scriptCall.stackRef(), std::move(streams));

  return future->then(
      [](JitFuture& future) {
        return withStorages(ScriptResp(future.value()).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonRemoteCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

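// Completes the OwnerRRef identified by rrefId once valueFuture finishes,
// propagating either the value or the error, and returns a future that
// holds the RemoteRet acknowledgement message for the caller.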
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::assignOwnerRRef(
    const RRefId& rrefId,
    const RRefId& forkId,
    c10::intrusive_ptr<JitFuture> valueFuture) const {
  auto& ctx = RRefContext::getInstance();

  c10::intrusive_ptr<OwnerRRef> ownerRRef;
  if (rrefId == forkId) {
    // Creating an owner RRef on self; it should already exist in the owners
    // map.
    ownerRRef =
        fromRRefInterface(ctx.getOwnerRRef(rrefId, /* forceCreated */ true)
                              ->constValue()
                              .toRRef());
  } else {
    ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, valueFuture->elementType());
    // The caller is a user and the callee is the owner, so add a fork.
    //
    // NB: rrefId == forkId is true if and only if calling remote to self.
    // In that case both the caller and the callee will access the
    // OwnerRRef. Hence, on the callee side (here), it should not call
    // addForkOfOwner as it is not a fork. To allow the callee to distinguish
    // when a request is sent to self, the caller sets forkId to rrefId
    // (OwnerRRef does not have a forkId anyway).
    ctx.addForkOfOwner(rrefId, forkId);
  }

  return valueFuture->then(
      [ownerRRef, rrefId, forkId](JitFuture& future) {
        if (future.hasError()) {
          ownerRRef->setError(future.exception_ptr());
        } else {
          ownerRRef->setValue(future.value());
        }
        return withStorages(RemoteRet(rrefId, forkId).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptRemoteCall(
    RpcCommandBase& rpc,
    std::vector<c10::Stream> streams) const {
  auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);

  TORCH_CHECK(
      scriptRemoteCall.hasOp(), "ScriptRemoteCall needs to have an op!");
  auto future = runJitOperator(
      *scriptRemoteCall.op(), scriptRemoteCall.stackRef(), std::move(streams));

  return assignOwnerRRef(
      scriptRemoteCall.retRRefId(),
      scriptRemoteCall.retForkId(),
      std::move(future));
}

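// Returns a future that resolves to the value held by the OwnerRRef with
// the given id, chaining through the future that produces the RRef itself.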
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::retrieveOwnerRRef(
    const RRefId& rrefId) const {
  auto& ctx = RRefContext::getInstance();

  auto rrefFuture = ctx.getOwnerRRef(rrefId);

  at::TypePtr type = rrefFuture->elementType();
  TORCH_INTERNAL_ASSERT(type->kind() == at::RRefType::Kind);
  return rrefFuture->thenAsync(
      [](JitFuture& rrefFuture) {
        c10::intrusive_ptr<OwnerRRef> rref =
            fromRRefInterface(rrefFuture.value().toRRef());
        return rref->getFuture();
      },
      type->cast<at::RRefType>()->getElementType());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processScriptRRefFetchCall(RpcCommandBase& rpc) const {
  auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);

  auto future = retrieveOwnerRRef(srf.rrefId());

  return future->then(
      [](JitFuture& future) {
        return withStorages(ScriptRRefFetchRet({future.value()}).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processPythonRRefFetchCall(RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefUserDelete(
    RpcCommandBase& rpc) const {
  auto& rud = static_cast<RRefUserDelete&>(rpc);
  auto& ctx = RRefContext::getInstance();
  auto deletedRRef = ctx.delForkOfOwner(rud.rrefId(), rud.forkId());
  handleRRefDelete(deletedRRef);
  return asFuture(RRefAck().toMessage());
}

void RequestCallbackNoPython::handleRRefDelete(
    c10::intrusive_ptr<RRef>& rref) const {
  TORCH_CHECK(!rref->isPyObj(), "RRefs with python objects not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefChildAccept(
    RpcCommandBase& rpc) const {
  auto& rca = static_cast<RRefChildAccept&>(rpc);
  auto& ctx = RRefContext::getInstance();
  ctx.delPendingChild(rca.forkId());
  return asFuture(RRefAck().toMessage());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefForkRequest(
    RpcCommandBase& rpc) const {
  auto& rfr = static_cast<RRefForkRequest&>(rpc);
  auto& ctx = RRefContext::getInstance();
  ctx.addForkOfOwnerIfNotPresent(rfr.rrefId(), rfr.forkId());
  return asFuture(RRefAck().toMessage());
}

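// Handles FORWARD_AUTOGRAD_REQ: attaches the 'recv' autograd function,
// processes the wrapped RPC under the sender's autograd context, and wraps
// the response with autograd metadata for the backward pass.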
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processForwardAutogradReq(
        RpcCommandBase& rpc,
        std::vector<c10::Stream> streams) const {
  auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

  // We need to reverse the device map for the backward pass of distributed
  // autograd.
  DeviceMap reverseDeviceMap;
  for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
    reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
  }

  // Attach the 'recv' autograd function.
  auto autogradContext = addRecvRpcBackward(
      rpcWithAutograd.autogradMetadata(),
      rpcWithAutograd.tensors(),
      rpcWithAutograd.fromWorkerId(),
      reverseDeviceMap);
  // For this recv thread on the server side, before processRpc(), set
  // current_context_id_ to the context_id passed from the client. This way,
  // if there is a nested rpc call within a python rpc call, the original
  // context_id from the client is passed along the chain of calls.
  TORCH_INTERNAL_ASSERT(
      autogradContext != nullptr,
      "autogradContext is nullptr, FORWARD_AUTOGRAD_REQ should always get "
      "or create a valid autogradContext in addRecvRpcBackward.");

  DistAutogradContextGuard ctxGuard(autogradContext->contextId());

  // Process the original RPC.
  auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
  // Kick off processing for the nested RPC command.
  // wrappedRpcResponseFuture will be a Future<T> to the result.
  auto wrappedRpcResponseFuture = processRpc(
      rpcWithAutograd.wrappedRpc(), wrappedMessageType, std::move(streams));

  auto fromWorkerId = rpcWithAutograd.fromWorkerId();
  // The original future needs to be marked as completed when the wrapped
  // one completes, with the autograd context information wrapped.
  auto responseFuture = wrappedRpcResponseFuture->then(
      [fromWorkerId, ctxId = autogradContext->contextId()](
          JitFuture& wrappedRpcResponseFuture) {
        // As this callback can be invoked by a different thread, we have to
        // make sure that the thread_local states of the previous thread are
        // correctly propagated.
        // NB: The execution of TorchScript functions can also run on a
        // different thread, which is addressed by
        // https://github.com/pytorch/pytorch/pull/36395
        // NB: when adding async UDF support, we should also propagate
        // thread_local states there.
        // TODO: Land on a general solution for RPC ThreadLocalState. See
        // https://github.com/pytorch/pytorch/issues/38510
        DistAutogradContextGuard cbCtxGuard(ctxId);

        if (wrappedRpcResponseFuture.hasError()) {
          // Propagate the error to responseFuture if we had one.
          std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
        } else {
          auto msg = getMessageWithAutograd(
              fromWorkerId,
              wrappedRpcResponseFuture.value().toCustomClass<Message>(),
              MessageType::FORWARD_AUTOGRAD_RESP);
          return withStorages(std::move(msg));
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());

  return responseFuture;
}

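// Handles BACKWARD_AUTOGRAD_REQ: looks up the autograd context and the
// matching 'send' function, attaches the received gradients, and kicks off
// the distributed autograd engine.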
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processBackwardAutogradReq(
        RpcCommandBase& rpc,
        std::vector<c10::Stream> streams) const {
  c10::MultiStreamGuard guard(streams);
  auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
  const auto& autogradMetadata = gradientsCall.getAutogradMetadata();

  // Retrieve the appropriate autograd context.
  auto autogradContext = DistAutogradContainer::getInstance().retrieveContext(
      autogradMetadata.autogradContextId);

  // Look up the appropriate 'send' function to enqueue.
  std::shared_ptr<SendRpcBackward> sendFunction =
      autogradContext->retrieveSendFunction(autogradMetadata.autogradMessageId);

  // Attach the gradients to the send function.
  sendFunction->setGrads(gradientsCall.getGrads());

  // Now execute the autograd graph using the "distributed engine".
  auto execFuture = DistEngine::getInstance().executeSendFunctionAsync(
      autogradContext, sendFunction, gradientsCall.retainGraph());

  // Our response is satisfied when the RPCs come back.
  return execFuture->then(
      [](JitFuture& execFuture) {
        if (execFuture.hasError()) {
          std::rethrow_exception(execFuture.exception_ptr());
        } else {
          return withStorages(PropagateGradientsResp().toMessage());
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processCleanupAutogradContextReq(RpcCommandBase& rpc) const {
  auto& cleanupContextReq = static_cast<CleanupAutogradContextReq&>(rpc);
  auto cleanupContextId = cleanupContextReq.getContextId();
  // Release the context if it still exists on this thread. We need to check
  // whether it exists since it may have been deleted by an in-flight RPC.
  // This can create nested RPCs if other nodes are notified to clean up
  // their contexts.
  DistAutogradContainer::getInstance().releaseContextIfPresent(
      cleanupContextId);
  return asFuture(CleanupAutogradContextResp().toMessage());
}

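// Handles RUN_WITH_PROFILING_REQ: enables the legacy profiler with the
// sender's config (downgrading Kineto and CUDA-unavailable configs to CPU),
// processes the wrapped RPC, and ships the collected events back in an
// RpcWithProfilingResp.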
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processRunWithProfilingReq(RpcCommandBase& rpc) const {
  auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
  auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType();
  auto profilingConfig = rpcWithProfilingReq.getProfilingConfig();

  if (profilingConfig.state == ProfilerState::KINETO ||
      profilingConfig.state == ProfilerState::KINETO_GPU_FALLBACK) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);
  }

  // If the caller requested CUDA profiling but CUDA is not available on this
  // machine, fall back to CPU and log a warning instead of crashing.
  if (profilingConfig.state == ProfilerState::CUDA && !this->cudaAvailable()) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);

    LOG(WARNING) << "Profiler was requested to be enabled with CUDA on this "
                    "node, but CUDA is not available. "
                 << "Falling back to CPU profiling only.";
  }
  TORCH_INTERNAL_ASSERT(
      profilingConfig.state != ProfilerState::CUDA || this->cudaAvailable(),
      "Profiler state set to CUDA but CUDA not available.");
  const auto profilingKeyId = rpcWithProfilingReq.getProfilingId();
  // Enable the profiler with the config from the sender.
  // When enabling on the main thread, ensure profiler states are cleaned
  // up, but defer consolidation of all profiled events to the continuation
  // below.
  ProfilerDisableOptions requestThreadOptions(
      true /* cleanup TLS state */, false /* consolidate events */);
  {
    TLSLegacyProfilerGuard g(
        profilingConfig, c10::nullopt, requestThreadOptions);
    TORCH_INTERNAL_ASSERT(
        profilerEnabled(), "Expected profiler to be enabled!");
    // Kick off processing for the nested work and get a Future<T> result in
    // wrappedRpcResponseFuture.
    auto wrappedRpcResponseFuture = processRpc(
        rpcWithProfilingReq.wrappedRpc(),
        wrappedMsgType,
        {}); // TODO: https://github.com/pytorch/pytorch/issues/55757

    auto responseFuture = wrappedRpcResponseFuture->then(
        at::wrapPropagateTLSState([profilingKeyId, profilingConfig](
                                      JitFuture& wrappedRpcResponseFuture) {
          std::vector<LegacyEvent> profiledEvents;
          // Defer consolidation of profiler events until async work has
          // completed (such as an async UDF).

          TORCH_INTERNAL_ASSERT(
              profilerEnabled(), "Expected profiler to be enabled!");

          // On the continuation thread, don't clean up profiler states,
          // since they will be cleaned up by the main thread, and
          // consolidate all events so we obtain asynchronously run events.
          ProfilerDisableOptions opts(false, true);
          auto event_lists = disableProfilerLegacy(opts);
          if (wrappedRpcResponseFuture.hasError()) {
            // Propagate the error.
            // No need to propagate remote events in the case of an error.
            std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
          } else {
            populateRemoteProfiledEvents(
                profiledEvents, profilingConfig, event_lists);
            auto rpcWithProfilingResp = std::make_unique<RpcWithProfilingResp>(
                MessageType::RUN_WITH_PROFILING_RESP,
                wrappedRpcResponseFuture.value().toCustomClass<Message>(),
                profiledEvents,
                profilingKeyId);
            return withStorages(std::move(*rpcWithProfilingResp).toMessage());
          }
        }),
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    return responseFuture;
    // Exiting the scope will disable the profiler on this thread with the
    // options specified above.
  }
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefBackward(
    RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    std::vector<c10::Stream> streams) const {
  // TODO: RpcCommandBase should have an abstract execute() method that we can
  // call here instead of having another switch statement. Better yet, we
  // could have abstract classes RpcRequest and RpcResp which inherit from
  // RpcCommandBase, where RpcRequest declares the abstract method execute()
  // that we call here. RpcResponse could have an abstract method to convert
  // it to a python object.
  switch (messageType) {
    case MessageType::SCRIPT_CALL: {
      return processScriptCall(rpc, std::move(streams));
    }
    case MessageType::PYTHON_CALL: {
      return processPythonCall(rpc, std::move(streams));
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      return processScriptRemoteCall(rpc, std::move(streams));
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      return processPythonRemoteCall(rpc, std::move(streams));
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      return processScriptRRefFetchCall(rpc);
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      return processPythonRRefFetchCall(rpc);
    }
    case MessageType::RREF_USER_DELETE: {
      return processRRefUserDelete(rpc);
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      return processRRefChildAccept(rpc);
    }
    case MessageType::RREF_FORK_REQUEST: {
      return processRRefForkRequest(rpc);
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      return processForwardAutogradReq(rpc, std::move(streams));
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      return processBackwardAutogradReq(rpc, std::move(streams));
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
      return processCleanupAutogradContextReq(rpc);
    }
    case MessageType::RUN_WITH_PROFILING_REQ: {
      return processRunWithProfilingReq(rpc);
    }
    case MessageType::RREF_BACKWARD_REQ: {
      return processRRefBackward(rpc);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", messageType, " not supported.");
    }
  }
}

c10::intrusive_ptr<Message> RequestCallbackNoPython::handleError(
    const std::exception& e,
    const MessageType messageType,
    int64_t messageId) const {
  LOG(ERROR) << "Received error while processing request type " << messageType
             << ": " << e.what();
  // Add node information to the error here since all processed RPC requests
  // should be going through this function.
  std::string errorMsg = c10::str(
      "Error on Node ",
      DistAutogradContainer::getInstance().getWorkerId(),
      ": ",
      e.what());
  return createExceptionResponse(errorMsg, messageId);
}

bool RequestCallbackNoPython::cudaAvailable() const {
#ifdef USE_CUDA
  return true;
#else
  return false;
#endif
}

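// Runs a JIT operator synchronously on the given streams and returns an
// already-completed future holding either its single return value or the
// exception it threw.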
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::runJitOperator(
    const jit::Operator& op,
    std::vector<at::IValue>& stack,
    std::vector<c10::Stream> streams) const {
  c10::MultiStreamGuard guard(streams);
  try {
    op.getOperation()(stack);
  } catch (const std::exception&) {
    return asFuture(std::current_exception());
  }
  TORCH_INTERNAL_ASSERT(
      stack.size() == 1,
      "Return value of a builtin operator or a TorchScript function should be "
      "a single IValue, got a vector of size ",
      stack.size());
  TypePtr type = stack.front().type();
  return asFuture(std::move(stack.front()), std::move(type));
}

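// The asFuture() overloads below wrap a value, a message, or an exception
// into an already-completed JitFuture on the current RpcAgent's devices.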
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    IValue value,
    TypePtr type) const {
  auto future = c10::make_intrusive<JitFuture>(
      std::move(type), RpcAgent::getCurrentRpcAgent()->getDevices());
  future->markCompleted(std::move(value));
  return future;
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    c10::intrusive_ptr<Message> message) const {
  auto future = c10::make_intrusive<JitFuture>(
      at::getCustomClassType<c10::intrusive_ptr<Message>>(),
      RpcAgent::getCurrentRpcAgent()->getDevices());
  std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages =
      message->getStorages();
  future->markCompleted(std::move(message), std::move(storages));
  return future;
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    std::exception_ptr err) const {
  auto future = c10::make_intrusive<JitFuture>(
      at::NoneType::get(), RpcAgent::getCurrentRpcAgent()->getDevices());
  future->setError(err);
  return future;
}

} // namespace rpc
} // namespace distributed
} // namespace torch