1#include <torch/csrc/distributed/rpc/request_callback_impl.h>
2
3#include <c10/util/C++17.h>
4#include <torch/csrc/autograd/profiler.h>
5#include <torch/csrc/distributed/autograd/context/container.h>
6#include <torch/csrc/distributed/autograd/context/context.h>
7#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
8#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
9#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
10#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
11#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
12#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
13#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h>
14#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h>
15#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
16#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h>
17#include <torch/csrc/distributed/autograd/utils.h>
18#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
19#include <torch/csrc/distributed/rpc/py_rref.h>
20#include <torch/csrc/distributed/rpc/python_call.h>
21#include <torch/csrc/distributed/rpc/python_remote_call.h>
22#include <torch/csrc/distributed/rpc/python_resp.h>
23#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
24#include <torch/csrc/distributed/rpc/rref_context.h>
25#include <torch/csrc/distributed/rpc/rref_impl.h>
26#include <torch/csrc/distributed/rpc/rref_proto.h>
27#include <torch/csrc/distributed/rpc/script_call.h>
28#include <torch/csrc/distributed/rpc/script_remote_call.h>
29#include <torch/csrc/distributed/rpc/script_resp.h>
30#include <torch/csrc/distributed/rpc/unpickled_python_call.h>
31#include <torch/csrc/distributed/rpc/unpickled_python_remote_call.h>
32#include <torch/csrc/distributed/rpc/utils.h>
33#include <torch/csrc/jit/python/python_ivalue.h>
34
35namespace torch {
36namespace distributed {
37namespace rpc {
38
39using namespace torch::distributed::autograd;
40
41namespace {
42
43std::unique_ptr<RpcCommandBase> deserializePythonRpcCommandReference(
44 RpcCommandBase& rpc,
45 const MessageType& messageType) {
46 switch (messageType) {
47 case MessageType::PYTHON_CALL: {
48 auto& pc = static_cast<PythonCall&>(rpc);
49 return std::make_unique<UnpickledPythonCall>(
50 pc.serializedPyObj(), pc.isAsyncExecution());
51 }
52 case MessageType::PYTHON_REMOTE_CALL: {
53 auto& prc = static_cast<PythonRemoteCall&>(rpc);
54 return std::make_unique<UnpickledPythonRemoteCall>(
55 prc.serializedPyObj(),
56 prc.retRRefId(),
57 prc.retForkId(),
58 prc.isAsyncExecution());
59 }
60 case MessageType::FORWARD_AUTOGRAD_REQ: {
61 // Deserialize the wrapped RPC if it contains Python UDF
62 auto& rwa = static_cast<RpcWithAutograd&>(rpc);
63 auto& wrappedRpc = rwa.wrappedRpc();
64 auto pythonRpc = deserializePythonRpcCommandReference(
65 wrappedRpc, rwa.wrappedMessageType());
66 if (pythonRpc) {
67 rwa.setWrappedRpc(std::move(pythonRpc));
68 }
69 return nullptr;
70 }
71 case MessageType::RUN_WITH_PROFILING_REQ: {
72 // Deserialize wrapped RPC if it contains python call
73 auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
74 auto& wrappedRpc = rpcWithProfilingReq.wrappedRpc();
75 auto pythonRpc = deserializePythonRpcCommandReference(
76 wrappedRpc, rpcWithProfilingReq.wrappedMessageType());
77 if (pythonRpc) {
78 rpcWithProfilingReq.setWrappedRpc(std::move(pythonRpc));
79 }
80 return nullptr;
81 }
82 default: {
83 return nullptr;
84 }
85 }
86}
87
88SerializedPyObj serializePyObject(IValue value) {
89 auto& pythonRpcHandler = PythonRpcHandler::getInstance();
90 // Need this GIL to guard jit::toPyObj and destruct its returned
91 // py::object
92 py::gil_scoped_acquire acquire;
93 try {
94 return pythonRpcHandler.serialize(jit::toPyObject(value));
95 } catch (py::error_already_set& e) {
96 // py::error_already_set requires GIL to destruct, take special care.
97 auto err = std::runtime_error(e.what());
98 e.restore();
99 PyErr_Clear();
100 throw err;
101 }
102}
103
104} // anonymous namespace
105
106c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runPythonFunction(
107 const py::object& function,
108 std::vector<c10::Stream> streams,
109 bool isAsyncExecution) const {
110 c10::MultiStreamGuard guard(streams);
111 auto& pythonRpcHandler = PythonRpcHandler::getInstance();
112 py::gil_scoped_acquire acquire;
113
114 py::object result;
115 try {
116 result = pythonRpcHandler.runPythonUdf(function);
117 } catch (py::error_already_set& e) {
118 // py::error_already_set requires GIL to destruct, take special care.
119 auto future =
120 asFuture(std::make_exception_ptr(std::runtime_error(e.what())));
121 e.restore();
122 PyErr_Clear();
123 return future;
124 } catch (std::exception& e) {
125 return asFuture(std::current_exception());
126 }
127
128 // After sync exection or failed async execution return the value as-is.
129 if (pythonRpcHandler.isRemoteException(result) || !isAsyncExecution) {
130 return asFuture(
131 c10::ivalue::ConcretePyObjectHolder::create(result),
132 at::PyObjectType::get());
133 }
134
135 try {
136 return result.cast<jit::PythonFutureWrapper&>().fut;
137 } catch (const py::cast_error& e) {
138 auto type = result.get_type();
139 auto errMsg = c10::str(
140 e.what(),
141 ". Functions decorated with @rpc.async_function must return a "
142 "torch.futures.Future object, but got ",
143 type.attr("__module__").cast<std::string>(),
144 ".",
145 type.attr("__qualname__").cast<std::string>());
146 return asFuture(std::make_exception_ptr(std::runtime_error(errMsg)));
147 }
148}
149
150std::unique_ptr<RpcCommandBase> RequestCallbackImpl::
151 deserializePythonRpcCommand(
152 std::unique_ptr<RpcCommandBase> rpc,
153 const MessageType& messageType) const {
154 auto pythonRpc = deserializePythonRpcCommandReference(*rpc, messageType);
155 return pythonRpc ? std::move(pythonRpc) : std::move(rpc);
156}
157
158c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptCall(
159 RpcCommandBase& rpc,
160 std::vector<c10::Stream> streams) const {
161 auto& scriptCall = static_cast<ScriptCall&>(rpc);
162
163 c10::intrusive_ptr<JitFuture> future;
164 if (scriptCall.hasOp()) {
165 future = runJitOperator(
166 *scriptCall.op(), scriptCall.stackRef(), std::move(streams));
167 } else {
168 future = runJitFunction(
169 scriptCall.qualifiedName(),
170 scriptCall.stackRef(),
171 std::move(streams),
172 scriptCall.isAsyncExecution());
173 }
174
175 return future->then(
176 [](JitFuture& jitFuture) {
177 return withStorages(ScriptResp(jitFuture.value()).toMessage());
178 },
179 c10::getCustomClassType<c10::intrusive_ptr<Message>>());
180}
181
182c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonCall(
183 RpcCommandBase& rpc,
184 std::vector<c10::Stream> streams) const {
185 auto& upc = static_cast<UnpickledPythonCall&>(rpc);
186 auto future = runPythonFunction(
187 upc.pythonUdf(), std::move(streams), upc.isAsyncExecution());
188
189 return future->then(
190 [](JitFuture& future) {
191 return withStorages(
192 PythonResp(serializePyObject(future.value())).toMessage());
193 },
194 c10::getCustomClassType<c10::intrusive_ptr<Message>>());
195}
196
197c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processScriptRemoteCall(
198 RpcCommandBase& rpc,
199 std::vector<c10::Stream> streams) const {
200 auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);
201
202 c10::intrusive_ptr<JitFuture> future;
203 if (scriptRemoteCall.hasOp()) {
204 future = runJitOperator(
205 *scriptRemoteCall.op(),
206 scriptRemoteCall.stackRef(),
207 std::move(streams));
208 } else {
209 future = runJitFunction(
210 scriptRemoteCall.qualifiedName(),
211 scriptRemoteCall.stackRef(),
212 std::move(streams),
213 scriptRemoteCall.isAsyncExecution());
214 }
215
216 return assignOwnerRRef(
217 scriptRemoteCall.retRRefId(),
218 scriptRemoteCall.retForkId(),
219 std::move(future));
220}
221
222c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRemoteCall(
223 RpcCommandBase& rpc,
224 std::vector<c10::Stream> streams) const {
225 auto& uprc = static_cast<UnpickledPythonRemoteCall&>(rpc);
226 auto future = runPythonFunction(
227 uprc.pythonUdf(), std::move(streams), uprc.isAsyncExecution());
228
229 return assignOwnerRRef(uprc.rrefId(), uprc.forkId(), std::move(future));
230}
231
232c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processPythonRRefFetchCall(
233 RpcCommandBase& rpc) const {
234 auto& prf = static_cast<PythonRRefFetchCall&>(rpc);
235
236 auto future = retrieveOwnerRRef(prf.rrefId());
237
238 return future->then(
239 [](JitFuture& future) {
240 SerializedPyObj result = serializePyObject(future.value());
241 return withStorages(
242 PythonRRefFetchRet(std::move(result).toIValues()).toMessage());
243 },
244 c10::getCustomClassType<c10::intrusive_ptr<Message>>());
245}
246
247void RequestCallbackImpl::handleRRefDelete(
248 c10::intrusive_ptr<RRef>& rref) const {
249 if (rref && rref->isPyObj()) {
250 py::gil_scoped_acquire acquire;
251 rref.reset();
252 }
253}
254
255c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRpcWithErrors(
256 RpcCommandBase& rpc,
257 const MessageType& messageType,
258 std::vector<c10::Stream> streams) const {
259 try {
260 return processRpc(rpc, messageType, std::move(streams));
261 } catch (py::error_already_set& e) {
262 // Pass a dummy message ID since it will be overwritten anyways.
263 auto future = asFuture(handleError(e, messageType, -1));
264 // There are request callback impls in Python, where Python
265 // exceptions could be thrown. For releasing Python exception
266 // py::objects, GIL must be held.
267 py::gil_scoped_acquire acquire;
268 e.restore(); // Release ownership on py::objects and also restore
269 // Python Error Indicator.
270 PyErr_Clear(); // Clear the Python Error Indicator as we has
271 // recorded the exception in the response message.
272 return future;
273 } catch (std::exception& e) {
274 // Pass a dummy message ID since it will be overwritten anyways.
275 return asFuture(handleError(e, messageType, -1));
276 }
277}
278
279bool RequestCallbackImpl::cudaAvailable() const {
280#ifdef USE_CUDA
281 return true;
282#else
283 return false;
284#endif
285}
286
287c10::intrusive_ptr<JitFuture> RequestCallbackImpl::processRRefBackward(
288 RpcCommandBase& rpc) const {
289 auto& rrefBackwardReq = static_cast<RRefBackwardReq&>(rpc);
290
291 auto future = retrieveOwnerRRef(rrefBackwardReq.getRRefId());
292
293 return future->then(
294 [autogradContextId = rrefBackwardReq.getAutogradContextId(),
295 retainGraph = rrefBackwardReq.retainGraph()](JitFuture& future) {
296 // Run backward (TODO: make this async?).
297 PyRRef::backwardOwnerRRef(
298 autogradContextId, retainGraph, future.value());
299
300 return withStorages(RRefBackwardResp().toMessage());
301 },
302 c10::getCustomClassType<c10::intrusive_ptr<Message>>());
303}
304
305c10::intrusive_ptr<JitFuture> RequestCallbackImpl::runJitFunction(
306 const c10::QualifiedName& name,
307 std::vector<at::IValue>& stack,
308 std::vector<c10::Stream> streams,
309 bool isAsyncExecution) const {
310 c10::MultiStreamGuard guard(streams);
311 c10::intrusive_ptr<JitFuture> future;
312 try {
313 // runAsync() starts in the calling thread, but may return an uncompleted
314 // future (though for non-async code, it will typically be completed).
315 // If it was async, our callback will typically be invoked by the
316 // continuation on an at::launch() thread.
317 future = PythonRpcHandler::getInstance()
318 .jitCompilationUnit()
319 ->get_function(name)
320 .runAsync(stack);
321 } catch (const std::exception&) {
322 return asFuture(std::current_exception());
323 }
324
325 if (isAsyncExecution) {
326 at::TypePtr type = future->elementType();
327 if (type->kind() != at::FutureType::Kind) {
328 return asFuture(std::make_exception_ptr(std::runtime_error(c10::str(
329 "Async functions must return an IValue of Future type, but got ",
330 type->str()))));
331 }
332 future = future->thenAsync(
333 [](JitFuture& future) { return future.value().toFuture(); },
334 type->cast<at::FutureType>()->getElementType());
335 }
336
337 return future;
338}
339
340} // namespace rpc
341} // namespace distributed
342} // namespace torch
343