1 | #include <ATen/ThreadLocalState.h> |
2 | #include <fmt/format.h> |
3 | #include <torch/csrc/autograd/record_function_ops.h> |
4 | #include <torch/csrc/distributed/autograd/utils.h> |
5 | #include <torch/csrc/distributed/rpc/message.h> |
6 | #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h> |
7 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
8 | #include <torch/csrc/distributed/rpc/rref_proto.h> |
9 | #include <torch/csrc/distributed/rpc/script_call.h> |
10 | #include <torch/csrc/distributed/rpc/torchscript_functions.h> |
11 | #include <torch/csrc/distributed/rpc/utils.h> |
12 | |
13 | namespace torch { |
14 | namespace distributed { |
15 | namespace rpc { |
16 | |
17 | c10::intrusive_ptr<JitFuture> rpcTorchscript( |
18 | const std::string& dstWorkerName, |
19 | const c10::QualifiedName& qualifiedName, |
20 | const c10::FunctionSchema& functionSchema, |
21 | std::vector<c10::IValue>& stack, |
22 | const float rpcTimeoutSeconds, |
23 | const bool isAsyncExecution) { |
24 | c10::intrusive_ptr<torch::autograd::profiler::PythonRecordFunction> record; |
25 | auto shouldProfile = torch::autograd::profiler::profilerEnabled() && |
26 | !torch::distributed::rpc::RemoteProfilerManager::getInstance() |
27 | .isCurrentKeySet(); |
28 | if (shouldProfile) { |
29 | auto rpcAsyncJitKey = fmt::format( |
30 | "rpc_async_jit#{}({} -> {})" , |
31 | qualifiedName |
32 | .qualifiedName(), /* name of torchscript function being run */ |
33 | RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_, |
34 | dstWorkerName); |
35 | record = |
36 | torch::autograd::profiler::record_function_enter_new(rpcAsyncJitKey); |
37 | auto& remoteProfilerManager = |
38 | torch::distributed::rpc::RemoteProfilerManager::getInstance(); |
39 | remoteProfilerManager.setCurrentKey(rpcAsyncJitKey); |
40 | } |
41 | auto scriptCall = std::make_unique<ScriptCall>( |
42 | qualifiedName, std::move(stack), isAsyncExecution); |
43 | auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent(); |
44 | auto jitFuture = autograd::sendMessageWithAutograd( |
45 | *rpcAgentPtr, |
46 | rpcAgentPtr->getWorkerInfo(dstWorkerName), |
47 | std::move(*scriptCall).toMessage(), |
48 | true /*forceGradRecording*/, |
49 | rpcTimeoutSeconds); |
50 | |
51 | // Get function return type to construct JitFuture. |
52 | auto returns = functionSchema.returns(); |
53 | // Script call only allows single IValue returned. |
54 | TORCH_INTERNAL_ASSERT( |
55 | returns.size() == 1, |
56 | "Return value of an annotated torchScript function should be a single " |
57 | "IValue." , |
58 | returns.size()); |
59 | auto returnType = returns.at(0).type(); |
60 | |
61 | // Create a JIT future and pass it to futMessage's callback to set state |
62 | // of the JIT future. |
63 | auto futPtr = jitFuture->createInstance(returnType); |
64 | jitFuture->addCallback(at::wrapPropagateTLSState([futPtr](JitFuture& future) { |
65 | if (future.hasError()) { |
66 | futPtr->setError(future.exception_ptr()); |
67 | } else { |
68 | futPtr->markCompleted( |
69 | deserializeRespToIValue( |
70 | *future.constValue().toCustomClass<Message>()), |
71 | future.storages()); |
72 | } |
73 | })); |
74 | if (shouldProfile) { |
75 | auto profiledFutPtr = |
76 | torch::autograd::profiler::_call_end_callbacks_on_fut_new( |
77 | record, futPtr); |
78 | return profiledFutPtr; |
79 | } |
80 | return futPtr; |
81 | } |
82 | |
83 | c10::intrusive_ptr<RRef> remoteTorchscript( |
84 | const std::string& dstWorkerName, |
85 | const c10::QualifiedName& qualifiedName, |
86 | const c10::FunctionSchema& functionSchema, |
87 | std::vector<c10::IValue>& stack, |
88 | const float rpcTimeoutSeconds, |
89 | const bool isAsyncExecution) { |
90 | auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent(); |
91 | auto dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName); |
92 | auto& ctx = RRefContext::getInstance(); |
93 | |
94 | // Get function return type to construct UserRRef. |
95 | auto returns = functionSchema.returns(); |
96 | // Script call only allows single IValue returned. |
97 | TORCH_INTERNAL_ASSERT( |
98 | returns.size() == 1, |
99 | "Return value of an annotated torchScript function should be a single " |
100 | "IValue." , |
101 | returns.size()); |
102 | auto returnType = returns.at(0).type(); |
103 | |
104 | if (ctx.getWorkerId() != dstWorkerInfo.id_) { |
105 | auto userRRefPtr = ctx.createUserRRef(dstWorkerInfo.id_, returnType); |
106 | |
107 | auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>( |
108 | qualifiedName, |
109 | std::move(stack), |
110 | userRRefPtr->rrefId(), |
111 | userRRefPtr->forkId(), |
112 | isAsyncExecution); |
113 | |
114 | auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd( |
115 | *rpcAgentPtr, |
116 | dstWorkerInfo, |
117 | std::move(*scriptRemoteCall).toMessage(), |
118 | true /*forceGradRecording*/, |
119 | rpcTimeoutSeconds /* timeout */); |
120 | |
121 | userRRefPtr->registerOwnerCreationFuture(jitFuture); |
122 | ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr); |
123 | jitFuture->addCallback(at::wrapPropagateTLSState( |
124 | [forkId{userRRefPtr->forkId()}](JitFuture& future) { |
125 | callback::confirmPendingUser(future, forkId); |
126 | })); |
127 | |
128 | return userRRefPtr; |
129 | } else { |
130 | auto ownerRRefPtr = ctx.createOwnerRRef(returnType); |
131 | // prevent this owner RRef from being deleted due to other forks |
132 | ctx.addSelfAsFork(ownerRRefPtr); |
133 | |
134 | auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>( |
135 | qualifiedName, |
136 | std::move(stack), |
137 | ownerRRefPtr->rrefId(), |
138 | ownerRRefPtr->rrefId(), |
139 | isAsyncExecution); |
140 | |
141 | auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd( |
142 | *rpcAgentPtr, |
143 | dstWorkerInfo, |
144 | std::move(*scriptRemoteCall).toMessage(), |
145 | true /*forceGradRecording*/, |
146 | rpcTimeoutSeconds /* timeout */); |
147 | |
148 | ownerRRefPtr->registerOwnerCreationFuture(jitFuture); |
149 | jitFuture->addCallback(at::wrapPropagateTLSState( |
150 | [ownerRRefId = ownerRRefPtr->rrefId()](JitFuture& future) { |
151 | callback::finishCreatingOwnerRRef(future, ownerRRefId); |
152 | })); |
153 | return ownerRRefPtr; |
154 | } |
155 | } |
156 | |
157 | } // namespace rpc |
158 | } // namespace distributed |
159 | } // namespace torch |
160 | |