1 | #include <torch/csrc/distributed/rpc/request_callback_impl.h> |
2 | |
3 | #include <c10/util/C++17.h> |
4 | #include <torch/csrc/autograd/profiler.h> |
5 | #include <torch/csrc/distributed/autograd/context/container.h> |
6 | #include <torch/csrc/distributed/autograd/context/context.h> |
7 | #include <torch/csrc/distributed/autograd/engine/dist_engine.h> |
8 | #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h> |
9 | #include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h> |
10 | #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h> |
11 | #include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h> |
12 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h> |
13 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h> |
14 | #include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h> |
15 | #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h> |
16 | #include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h> |
17 | #include <torch/csrc/distributed/autograd/utils.h> |
18 | #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h> |
19 | #include <torch/csrc/distributed/rpc/py_rref.h> |
20 | #include <torch/csrc/distributed/rpc/python_call.h> |
21 | #include <torch/csrc/distributed/rpc/python_remote_call.h> |
22 | #include <torch/csrc/distributed/rpc/python_resp.h> |
23 | #include <torch/csrc/distributed/rpc/python_rpc_handler.h> |
24 | #include <torch/csrc/distributed/rpc/rref_context.h> |
25 | #include <torch/csrc/distributed/rpc/rref_impl.h> |
26 | #include <torch/csrc/distributed/rpc/rref_proto.h> |
27 | #include <torch/csrc/distributed/rpc/script_call.h> |
28 | #include <torch/csrc/distributed/rpc/script_remote_call.h> |
29 | #include <torch/csrc/distributed/rpc/script_resp.h> |
30 | #include <torch/csrc/distributed/rpc/unpickled_python_call.h> |
31 | #include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h> |
32 | #include <torch/csrc/distributed/rpc/utils.h> |
33 | #include <torch/csrc/jit/python/python_ivalue.h> |
34 | |
35 | namespace torch { |
36 | namespace distributed { |
37 | namespace rpc { |
38 | |
39 | using namespace torch::distributed::autograd; |
40 | |
41 | namespace { |
42 | |
43 | std::unique_ptr<RpcCommandBase> deserializePythonRpcCommandReference( |
44 | RpcCommandBase& rpc, |
45 | const MessageType& messageType) { |
46 | switch (messageType) { |
47 | case MessageType::PYTHON_CALL: { |
48 | auto& pc = static_cast<PythonCall&>(rpc); |
49 | return std::make_unique<UnpickledPythonCall>( |
50 | pc.serializedPyObj(), pc.isAsyncExecution()); |
51 | } |
52 | case MessageType::PYTHON_REMOTE_CALL: { |
53 | auto& prc = static_cast<PythonRemoteCall&>(rpc); |
54 | return std::make_unique<UnpickledPythonRemoteCall>( |
55 | prc.serializedPyObj(), |
56 | prc.retRRefId(), |
57 | prc.retForkId(), |
58 | prc.isAsyncExecution()); |
59 | } |
60 | case MessageType::FORWARD_AUTOGRAD_REQ: { |
61 | // Deserialize the wrapped RPC if it contains Python UDF |
62 | auto& rwa = static_cast<RpcWithAutograd&>(rpc); |
63 | auto& wrappedRpc = rwa.wrappedRpc(); |
64 | auto pythonRpc = deserializePythonRpcCommandReference( |
65 | wrappedRpc, rwa.wrappedMessageType()); |
66 | if (pythonRpc) { |
67 | rwa.setWrappedRpc(std::move(pythonRpc)); |
68 | } |
69 | return nullptr; |
70 | } |
71 | case MessageType::RUN_WITH_PROFILING_REQ: { |
72 | // Deserialize wrapped RPC if it contains python call |
73 | auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc); |
74 | auto& wrappedRpc = rpcWithProfilingReq.wrappedRpc(); |
75 | auto pythonRpc = deserializePythonRpcCommandReference( |
76 | wrappedRpc, rpcWithProfilingReq.wrappedMessageType()); |
77 | if (pythonRpc) { |
78 | rpcWithProfilingReq.setWrappedRpc(std::move(pythonRpc)); |
79 | } |
80 | return nullptr; |
81 | } |
82 | default: { |
83 | return nullptr; |
84 | } |
85 | } |
86 | } |
87 | |
88 | SerializedPyObj serializePyObject(IValue value) { |
89 | auto& pythonRpcHandler = PythonRpcHandler::getInstance(); |
90 | // Need this GIL to guard jit::toPyObj and destruct its returned |
91 | // py::object |
92 | py::gil_scoped_acquire acquire; |
93 | try { |
94 | return pythonRpcHandler.serialize(jit::toPyObject(value)); |
95 | } catch (py::error_already_set& e) { |
96 | // py::error_already_set requires GIL to destruct, take special care. |
97 | auto err = std::runtime_error(e.what()); |
98 | e.restore(); |
99 | PyErr_Clear(); |
100 | throw err; |
101 | } |
102 | } |
103 | |
104 | } // anonymous namespace |
105 | |
106 | c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runPythonFunction( |
107 | const py::object& function, |
108 | std::vector<c10::Stream> streams, |
109 | bool isAsyncExecution) const { |
110 | c10::MultiStreamGuard guard(streams); |
111 | auto& pythonRpcHandler = PythonRpcHandler::getInstance(); |
112 | py::gil_scoped_acquire acquire; |
113 | |
114 | py::object result; |
115 | try { |
116 | result = pythonRpcHandler.runPythonUdf(function); |
117 | } catch (py::error_already_set& e) { |
118 | // py::error_already_set requires GIL to destruct, take special care. |
119 | auto future = |
120 | asFuture(std::make_exception_ptr(std::runtime_error(e.what()))); |
121 | e.restore(); |
122 | PyErr_Clear(); |
123 | return future; |
124 | } catch (std::exception& e) { |
125 | return asFuture(std::current_exception()); |
126 | } |
127 | |
128 | // After sync exection or failed async execution return the value as-is. |
129 | if (pythonRpcHandler.isRemoteException(result) || !isAsyncExecution) { |
130 | return asFuture( |
131 | c10::ivalue::ConcretePyObjectHolder::create(result), |
132 | at::PyObjectType::get()); |
133 | } |
134 | |
135 | try { |
136 | return result.cast<jit::PythonFutureWrapper&>().fut; |
137 | } catch (const py::cast_error& e) { |
138 | auto type = result.get_type(); |
139 | auto errMsg = c10::str( |
140 | e.what(), |
141 | ". Functions decorated with @rpc.async_function must return a " |
142 | "torch.futures.Future object, but got " , |
143 | type.attr("__module__" ).cast<std::string>(), |
144 | "." , |
145 | type.attr("__qualname__" ).cast<std::string>()); |
146 | return asFuture(std::make_exception_ptr(std::runtime_error(errMsg))); |
147 | } |
148 | } |
149 | |
150 | std::unique_ptr<RpcCommandBase> RequestCallbackImpl:: |
151 | deserializePythonRpcCommand( |
152 | std::unique_ptr<RpcCommandBase> rpc, |
153 | const MessageType& messageType) const { |
154 | auto pythonRpc = deserializePythonRpcCommandReference(*rpc, messageType); |
155 | return pythonRpc ? std::move(pythonRpc) : std::move(rpc); |
156 | } |
157 | |
158 | c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptCall( |
159 | RpcCommandBase& rpc, |
160 | std::vector<c10::Stream> streams) const { |
161 | auto& scriptCall = static_cast<ScriptCall&>(rpc); |
162 | |
163 | c10::intrusive_ptr<JitFuture> future; |
164 | if (scriptCall.hasOp()) { |
165 | future = runJitOperator( |
166 | *scriptCall.op(), scriptCall.stackRef(), std::move(streams)); |
167 | } else { |
168 | future = runJitFunction( |
169 | scriptCall.qualifiedName(), |
170 | scriptCall.stackRef(), |
171 | std::move(streams), |
172 | scriptCall.isAsyncExecution()); |
173 | } |
174 | |
175 | return future->then( |
176 | [](JitFuture& jitFuture) { |
177 | return withStorages(ScriptResp(jitFuture.value()).toMessage()); |
178 | }, |
179 | c10::getCustomClassType<c10::intrusive_ptr<Message>>()); |
180 | } |
181 | |
182 | c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonCall( |
183 | RpcCommandBase& rpc, |
184 | std::vector<c10::Stream> streams) const { |
185 | auto& upc = static_cast<UnpickledPythonCall&>(rpc); |
186 | auto future = runPythonFunction( |
187 | upc.pythonUdf(), std::move(streams), upc.isAsyncExecution()); |
188 | |
189 | return future->then( |
190 | [](JitFuture& future) { |
191 | return withStorages( |
192 | PythonResp(serializePyObject(future.value())).toMessage()); |
193 | }, |
194 | c10::getCustomClassType<c10::intrusive_ptr<Message>>()); |
195 | } |
196 | |
197 | c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptRemoteCall( |
198 | RpcCommandBase& rpc, |
199 | std::vector<c10::Stream> streams) const { |
200 | auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc); |
201 | |
202 | c10::intrusive_ptr<JitFuture> future; |
203 | if (scriptRemoteCall.hasOp()) { |
204 | future = runJitOperator( |
205 | *scriptRemoteCall.op(), |
206 | scriptRemoteCall.stackRef(), |
207 | std::move(streams)); |
208 | } else { |
209 | future = runJitFunction( |
210 | scriptRemoteCall.qualifiedName(), |
211 | scriptRemoteCall.stackRef(), |
212 | std::move(streams), |
213 | scriptRemoteCall.isAsyncExecution()); |
214 | } |
215 | |
216 | return assignOwnerRRef( |
217 | scriptRemoteCall.retRRefId(), |
218 | scriptRemoteCall.retForkId(), |
219 | std::move(future)); |
220 | } |
221 | |
222 | c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRemoteCall( |
223 | RpcCommandBase& rpc, |
224 | std::vector<c10::Stream> streams) const { |
225 | auto& uprc = static_cast<UnpickledPythonRemoteCall&>(rpc); |
226 | auto future = runPythonFunction( |
227 | uprc.pythonUdf(), std::move(streams), uprc.isAsyncExecution()); |
228 | |
229 | return assignOwnerRRef(uprc.rrefId(), uprc.forkId(), std::move(future)); |
230 | } |
231 | |
232 | c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRRefFetchCall( |
233 | RpcCommandBase& rpc) const { |
234 | auto& prf = static_cast<PythonRRefFetchCall&>(rpc); |
235 | |
236 | auto future = retrieveOwnerRRef(prf.rrefId()); |
237 | |
238 | return future->then( |
239 | [](JitFuture& future) { |
240 | SerializedPyObj result = serializePyObject(future.value()); |
241 | return withStorages( |
242 | PythonRRefFetchRet(std::move(result).toIValues()).toMessage()); |
243 | }, |
244 | c10::getCustomClassType<c10::intrusive_ptr<Message>>()); |
245 | } |
246 | |
247 | void RequestCallbackImpl::handleRRefDelete( |
248 | c10::intrusive_ptr<RRef>& rref) const { |
249 | if (rref && rref->isPyObj()) { |
250 | py::gil_scoped_acquire acquire; |
251 | rref.reset(); |
252 | } |
253 | } |
254 | |
255 | c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRpcWithErrors( |
256 | RpcCommandBase& rpc, |
257 | const MessageType& messageType, |
258 | std::vector<c10::Stream> streams) const { |
259 | try { |
260 | return processRpc(rpc, messageType, std::move(streams)); |
261 | } catch (py::error_already_set& e) { |
262 | // Pass a dummy message ID since it will be overwritten anyways. |
263 | auto future = asFuture(handleError(e, messageType, -1)); |
264 | // There are request callback impls in Python, where Python |
265 | // exceptions could be thrown. For releasing Python exception |
266 | // py::objects, GIL must be held. |
267 | py::gil_scoped_acquire acquire; |
268 | e.restore(); // Release ownership on py::objects and also restore |
269 | // Python Error Indicator. |
270 | PyErr_Clear(); // Clear the Python Error Indicator as we has |
271 | // recorded the exception in the response message. |
272 | return future; |
273 | } catch (std::exception& e) { |
274 | // Pass a dummy message ID since it will be overwritten anyways. |
275 | return asFuture(handleError(e, messageType, -1)); |
276 | } |
277 | } |
278 | |
279 | bool RequestCallbackImpl::cudaAvailable() const { |
280 | #ifdef USE_CUDA |
281 | return true; |
282 | #else |
283 | return false; |
284 | #endif |
285 | } |
286 | |
287 | c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRRefBackward( |
288 | RpcCommandBase& rpc) const { |
289 | auto& rrefBackwardReq = static_cast<RRefBackwardReq&>(rpc); |
290 | |
291 | auto future = retrieveOwnerRRef(rrefBackwardReq.getRRefId()); |
292 | |
293 | return future->then( |
294 | [autogradContextId = rrefBackwardReq.getAutogradContextId(), |
295 | retainGraph = rrefBackwardReq.retainGraph()](JitFuture& future) { |
296 | // Run backward (TODO: make this async?). |
297 | PyRRef::backwardOwnerRRef( |
298 | autogradContextId, retainGraph, future.value()); |
299 | |
300 | return withStorages(RRefBackwardResp().toMessage()); |
301 | }, |
302 | c10::getCustomClassType<c10::intrusive_ptr<Message>>()); |
303 | } |
304 | |
305 | c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runJitFunction( |
306 | const c10::QualifiedName& name, |
307 | std::vector<at::IValue>& stack, |
308 | std::vector<c10::Stream> streams, |
309 | bool isAsyncExecution) const { |
310 | c10::MultiStreamGuard guard(streams); |
311 | c10::intrusive_ptr<JitFuture> future; |
312 | try { |
313 | // runAsync() starts in the calling thread, but may return an uncompleted |
314 | // future (though for non-async code, it will typically be completed). |
315 | // If it was async, our callback will typically be invoked by the |
316 | // continuation on an at::launch() thread. |
317 | future = PythonRpcHandler::getInstance() |
318 | .jitCompilationUnit() |
319 | ->get_function(name) |
320 | .runAsync(stack); |
321 | } catch (const std::exception&) { |
322 | return asFuture(std::current_exception()); |
323 | } |
324 | |
325 | if (isAsyncExecution) { |
326 | at::TypePtr type = future->elementType(); |
327 | if (type->kind() != at::FutureType::Kind) { |
328 | return asFuture(std::make_exception_ptr(std::runtime_error(c10::str( |
329 | "Async functions must return an IValue of Future type, but got " , |
330 | type->str())))); |
331 | } |
332 | future = future->thenAsync( |
333 | [](JitFuture& future) { return future.value().toFuture(); }, |
334 | type->cast<at::FutureType>()->getElementType()); |
335 | } |
336 | |
337 | return future; |
338 | } |
339 | |
340 | } // namespace rpc |
341 | } // namespace distributed |
342 | } // namespace torch |
343 | |