1#include <ATen/ThreadLocalState.h>
2#include <fmt/format.h>
3#include <torch/csrc/autograd/record_function_ops.h>
4#include <torch/csrc/distributed/autograd/utils.h>
5#include <torch/csrc/distributed/rpc/message.h>
6#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
7#include <torch/csrc/distributed/rpc/rpc_agent.h>
8#include <torch/csrc/distributed/rpc/rref_proto.h>
9#include <torch/csrc/distributed/rpc/script_call.h>
10#include <torch/csrc/distributed/rpc/torchscript_functions.h>
11#include <torch/csrc/distributed/rpc/utils.h>
12
13namespace torch {
14namespace distributed {
15namespace rpc {
16
17c10::intrusive_ptr<JitFuture> rpcTorchscript(
18 const std::string& dstWorkerName,
19 const c10::QualifiedName& qualifiedName,
20 const c10::FunctionSchema& functionSchema,
21 std::vector<c10::IValue>& stack,
22 const float rpcTimeoutSeconds,
23 const bool isAsyncExecution) {
24 c10::intrusive_ptr<torch::autograd::profiler::PythonRecordFunction> record;
25 auto shouldProfile = torch::autograd::profiler::profilerEnabled() &&
26 !torch::distributed::rpc::RemoteProfilerManager::getInstance()
27 .isCurrentKeySet();
28 if (shouldProfile) {
29 auto rpcAsyncJitKey = fmt::format(
30 "rpc_async_jit#{}({} -> {})",
31 qualifiedName
32 .qualifiedName(), /* name of torchscript function being run */
33 RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
34 dstWorkerName);
35 record =
36 torch::autograd::profiler::record_function_enter_new(rpcAsyncJitKey);
37 auto& remoteProfilerManager =
38 torch::distributed::rpc::RemoteProfilerManager::getInstance();
39 remoteProfilerManager.setCurrentKey(rpcAsyncJitKey);
40 }
41 auto scriptCall = std::make_unique<ScriptCall>(
42 qualifiedName, std::move(stack), isAsyncExecution);
43 auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
44 auto jitFuture = autograd::sendMessageWithAutograd(
45 *rpcAgentPtr,
46 rpcAgentPtr->getWorkerInfo(dstWorkerName),
47 std::move(*scriptCall).toMessage(),
48 true /*forceGradRecording*/,
49 rpcTimeoutSeconds);
50
51 // Get function return type to construct JitFuture.
52 auto returns = functionSchema.returns();
53 // Script call only allows single IValue returned.
54 TORCH_INTERNAL_ASSERT(
55 returns.size() == 1,
56 "Return value of an annotated torchScript function should be a single "
57 "IValue.",
58 returns.size());
59 auto returnType = returns.at(0).type();
60
61 // Create a JIT future and pass it to futMessage's callback to set state
62 // of the JIT future.
63 auto futPtr = jitFuture->createInstance(returnType);
64 jitFuture->addCallback(at::wrapPropagateTLSState([futPtr](JitFuture& future) {
65 if (future.hasError()) {
66 futPtr->setError(future.exception_ptr());
67 } else {
68 futPtr->markCompleted(
69 deserializeRespToIValue(
70 *future.constValue().toCustomClass<Message>()),
71 future.storages());
72 }
73 }));
74 if (shouldProfile) {
75 auto profiledFutPtr =
76 torch::autograd::profiler::_call_end_callbacks_on_fut_new(
77 record, futPtr);
78 return profiledFutPtr;
79 }
80 return futPtr;
81}
82
83c10::intrusive_ptr<RRef> remoteTorchscript(
84 const std::string& dstWorkerName,
85 const c10::QualifiedName& qualifiedName,
86 const c10::FunctionSchema& functionSchema,
87 std::vector<c10::IValue>& stack,
88 const float rpcTimeoutSeconds,
89 const bool isAsyncExecution) {
90 auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
91 auto dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName);
92 auto& ctx = RRefContext::getInstance();
93
94 // Get function return type to construct UserRRef.
95 auto returns = functionSchema.returns();
96 // Script call only allows single IValue returned.
97 TORCH_INTERNAL_ASSERT(
98 returns.size() == 1,
99 "Return value of an annotated torchScript function should be a single "
100 "IValue.",
101 returns.size());
102 auto returnType = returns.at(0).type();
103
104 if (ctx.getWorkerId() != dstWorkerInfo.id_) {
105 auto userRRefPtr = ctx.createUserRRef(dstWorkerInfo.id_, returnType);
106
107 auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
108 qualifiedName,
109 std::move(stack),
110 userRRefPtr->rrefId(),
111 userRRefPtr->forkId(),
112 isAsyncExecution);
113
114 auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd(
115 *rpcAgentPtr,
116 dstWorkerInfo,
117 std::move(*scriptRemoteCall).toMessage(),
118 true /*forceGradRecording*/,
119 rpcTimeoutSeconds /* timeout */);
120
121 userRRefPtr->registerOwnerCreationFuture(jitFuture);
122 ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr);
123 jitFuture->addCallback(at::wrapPropagateTLSState(
124 [forkId{userRRefPtr->forkId()}](JitFuture& future) {
125 callback::confirmPendingUser(future, forkId);
126 }));
127
128 return userRRefPtr;
129 } else {
130 auto ownerRRefPtr = ctx.createOwnerRRef(returnType);
131 // prevent this owner RRef from being deleted due to other forks
132 ctx.addSelfAsFork(ownerRRefPtr);
133
134 auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
135 qualifiedName,
136 std::move(stack),
137 ownerRRefPtr->rrefId(),
138 ownerRRefPtr->rrefId(),
139 isAsyncExecution);
140
141 auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd(
142 *rpcAgentPtr,
143 dstWorkerInfo,
144 std::move(*scriptRemoteCall).toMessage(),
145 true /*forceGradRecording*/,
146 rpcTimeoutSeconds /* timeout */);
147
148 ownerRRefPtr->registerOwnerCreationFuture(jitFuture);
149 jitFuture->addCallback(at::wrapPropagateTLSState(
150 [ownerRRefId = ownerRRefPtr->rrefId()](JitFuture& future) {
151 callback::finishCreatingOwnerRRef(future, ownerRRefId);
152 }));
153 return ownerRRefPtr;
154 }
155}
156
157} // namespace rpc
158} // namespace distributed
159} // namespace torch
160