1 | #include <ATen/ThreadLocalState.h> |
2 | #include <c10/util/C++17.h> |
3 | #include <torch/csrc/distributed/autograd/context/container.h> |
4 | #include <torch/csrc/distributed/autograd/utils.h> |
5 | #include <torch/csrc/distributed/rpc/message.h> |
6 | #include <torch/csrc/distributed/rpc/python_call.h> |
7 | #include <torch/csrc/distributed/rpc/python_functions.h> |
8 | #include <torch/csrc/distributed/rpc/python_remote_call.h> |
9 | #include <torch/csrc/distributed/rpc/python_resp.h> |
10 | #include <torch/csrc/distributed/rpc/python_rpc_handler.h> |
11 | #include <torch/csrc/distributed/rpc/rref_context.h> |
12 | #include <torch/csrc/distributed/rpc/rref_proto.h> |
13 | #include <torch/csrc/distributed/rpc/script_call.h> |
14 | #include <torch/csrc/distributed/rpc/script_remote_call.h> |
15 | #include <torch/csrc/distributed/rpc/script_resp.h> |
16 | #include <torch/csrc/distributed/rpc/torchscript_functions.h> |
17 | #include <torch/csrc/distributed/rpc/utils.h> |
18 | #include <torch/csrc/jit/runtime/operator.h> |
19 | #include <torch/csrc/utils/python_compat.h> |
20 | #include <exception> |
21 | |
22 | namespace torch { |
23 | namespace distributed { |
24 | namespace rpc { |
25 | |
26 | namespace { |
27 | |
// Converts an RPC response Message (SCRIPT_RET or PYTHON_RET) into an IValue
// wrapping the corresponding Python object. Acquires the GIL internally only
// for the sections that create or destroy py::objects; any other message
// type is a hard error via TORCH_CHECK.
IValue toPyIValue(const Message& message) {
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  switch (msgType) {
    case MessageType::SCRIPT_RET: {
      auto& ret = static_cast<ScriptResp&>(*response);
      Stack stack;
      stack.push_back(ret.value());
      // Need GIL to guard createPyObjectForStack() and its returned
      // py::object
      py::gil_scoped_acquire acquire;
      return jit::toIValue(
          torch::jit::createPyObjectForStack(std::move(stack)),
          PyObjectType::get());
    }
    case MessageType::PYTHON_RET: {
      // TODO: Try to avoid a copy here.
      auto& resp = static_cast<PythonResp&>(*response);
      auto& pythonRpcHandler = PythonRpcHandler::getInstance();
      // Need GIL to destruct the py::object returned by deserialize()
      py::gil_scoped_acquire acquire;
      py::object value = pythonRpcHandler.deserialize(resp.serializedPyObj());
      // Re-raises on the local side if the remote replied with a serialized
      // exception instead of a normal result.
      pythonRpcHandler.handleException(value);
      return jit::toIValue(value, PyObjectType::get());
    }
    default: {
      TORCH_CHECK(false, "Unrecognized response message type " , msgType);
    }
  }
}
58 | |
59 | std::shared_ptr<Operator> matchBuiltinOp( |
60 | const std::string& opName, |
61 | const py::args& args, |
62 | const py::kwargs& kwargs, |
63 | Stack& stack) { |
64 | Symbol symbol = Symbol::fromQualString(opName); |
65 | |
66 | std::shared_ptr<jit::Operator> matchedOperator; |
67 | if (symbol.is_aten()) { |
68 | // Prefer C10 ops so that they go through C10 dispatch. We expect the |
69 | // total # of possible overloaded ops (i.e. size of below ops list) to be |
70 | // small (i.e. it is 10 for torch.add) so a worst-case linear search should |
71 | // not incur significant extra overhead. |
72 | auto ops = torch::jit::getAllOperatorsFor(symbol); |
73 | std::vector<std::shared_ptr<torch::jit::Operator>> c10OpsForSymbol; |
74 | for (auto it = ops.begin(); it != ops.end();) { |
75 | std::shared_ptr<jit::Operator> op = *it; |
76 | if (op->isC10Op()) { |
77 | c10OpsForSymbol.emplace_back(std::move(op)); |
78 | it = ops.erase(it); |
79 | } else { |
80 | ++it; |
81 | } |
82 | } |
83 | |
84 | // Don't throw on failures in this call, since we are not examining on all |
85 | // operators here, and the matched operator may indeed not be a c10 op. |
86 | std::pair<std::shared_ptr<torch::jit::Operator>, torch::jit::Stack> |
87 | opWithStack; |
88 | try { |
89 | opWithStack = torch::jit::getOpWithStack(c10OpsForSymbol, args, kwargs); |
90 | } catch (const std::runtime_error& e) { |
91 | opWithStack = torch::jit::getOpWithStack(ops, args, kwargs); |
92 | } |
93 | matchedOperator = std::get<0>(opWithStack); |
94 | stack = std::get<1>(opWithStack); |
95 | } |
96 | |
97 | // We should never hit this path, since if !matchedOperator, then the last |
98 | // call to getOpWithStack should have thrown. |
99 | TORCH_CHECK( |
100 | matchedOperator != nullptr, |
101 | "Failed to match operator name " , |
102 | opName, |
103 | " and arguments " |
104 | "(args: " , |
105 | args, |
106 | ", kwargs: " , |
107 | kwargs, |
108 | ") to a builtin operator" ); |
109 | |
110 | return matchedOperator; |
111 | } |
112 | |
113 | c10::intrusive_ptr<JitFuture> sendPythonRemoteCall( |
114 | const WorkerInfo& dst, |
115 | SerializedPyObj serializedPyObj, |
116 | const IValue& rrefId, |
117 | const IValue& forkId, |
118 | const float rpcTimeoutSeconds, |
119 | const bool isAsyncExecution) { |
120 | auto pythonRemoteCall = std::make_unique<PythonRemoteCall>( |
121 | std::move(serializedPyObj), rrefId, forkId, isAsyncExecution); |
122 | |
123 | // set forceGradRecording to true as even if the args does not contain any |
124 | // tensor, the return value might still contain tensors. |
125 | auto agent = RpcAgent::getCurrentRpcAgent(); |
126 | return torch::distributed::autograd::sendMessageWithAutograd( |
127 | *agent, |
128 | dst, |
129 | std::move(*pythonRemoteCall).toMessage(), |
130 | true /*forceGradRecording*/, |
131 | rpcTimeoutSeconds); |
132 | } |
133 | |
134 | } // namespace |
135 | |
136 | using namespace torch::distributed::autograd; |
137 | |
// Wraps a future of a raw RPC Message into a future of the Python-level
// result. When hasValue is true, returns a child future of PyObject type
// completed with the deserialized response (or its error); when false, the
// returned future only signals completion/error and carries a None value.
c10::intrusive_ptr<JitFuture> toPyJitFuture(
    const c10::intrusive_ptr<JitFuture>& messageJitFuture,
    bool hasValue) {
  if (hasValue) {
    auto child = messageJitFuture->createInstance(PyObjectType::get());
    messageJitFuture->addCallback(
        // Propagate thread-local state (e.g. dist autograd context) from
        // this thread into the thread that runs the callback.
        at::wrapPropagateTLSState([child](JitFuture& future) {
          if (future.hasError()) {
            child->setError(future.exception_ptr());
          } else {
            const Message& message = *future.value().toCustomClass<Message>();

            // toPyIValue might throw and we need to record the appropriate
            // exception.
            IValue ivalue;
            try {
              ivalue = toPyIValue(message);
            } catch (py::error_already_set& e) {
              py::gil_scoped_acquire acquire;
              // FIXME: this is a temporary solution to add a special-case for
              // ValueError and TypeError, as those are already used in our
              // tests. We should have a more comprehensive coverage for other
              // types of exceptions as well.
              if (e.matches(PyExc_ValueError)) {
                child->setErrorIfNeeded(
                    std::make_exception_ptr(pybind11::value_error(e.what())));
              } else if (e.matches(PyExc_TypeError)) {
                child->setErrorIfNeeded(
                    std::make_exception_ptr(pybind11::type_error(e.what())));
              } else {
                // py::error_already_set requires GIL to destruct, take special
                // care.
                child->setErrorIfNeeded(
                    std::make_exception_ptr(std::runtime_error(e.what())));
              }
              // Restore the Python error state held by `e`, then clear it so
              // this callback thread leaves no pending Python exception.
              e.restore();
              PyErr_Clear();
              return;
            } catch (std::exception& e) {
              child->setErrorIfNeeded(std::current_exception());
              return;
            }

            child->markCompleted(ivalue, future.storages());
          }
        }));
    return child;
  } else {
    // Caller does not want the value: just surface completion or rethrow the
    // error, producing a None-typed future.
    return messageJitFuture->then(
        at::wrapPropagateTLSState([](JitFuture& future) {
          if (future.hasError()) {
            std::rethrow_exception(future.exception_ptr());
          } else {
            return IValue();
          }
        }),
        NoneType::get());
  }
}
197 | |
198 | c10::intrusive_ptr<JitFuture> pyRpcBuiltin( |
199 | const WorkerInfo& dst, |
200 | const std::string& opName, |
201 | const py::args& args, |
202 | const py::kwargs& kwargs, |
203 | const float rpcTimeoutSeconds) { |
204 | DCHECK(PyGILState_Check()); |
205 | Stack stack; |
206 | auto op = matchBuiltinOp(opName, args, kwargs, stack); |
207 | // Release GIL since args and kwargs processing is done. |
208 | py::gil_scoped_release release; |
209 | auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack)); |
210 | auto agent = RpcAgent::getCurrentRpcAgent(); |
211 | return toPyJitFuture(sendMessageWithAutograd( |
212 | *agent, |
213 | dst, |
214 | std::move(*scriptCall).toMessage(), |
215 | false, |
216 | rpcTimeoutSeconds)); |
217 | } |
218 | |
219 | c10::intrusive_ptr<JitFuture> pyRpcPythonUdf( |
220 | const WorkerInfo& dst, |
221 | std::string& pickledPythonUDF, |
222 | std::vector<torch::Tensor>& tensors, |
223 | const float rpcTimeoutSeconds, |
224 | const bool isAsyncExecution) { |
225 | DCHECK(!PyGILState_Check()); |
226 | auto serializedPyObj = |
227 | SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors)); |
228 | auto pythonCall = std::make_unique<PythonCall>( |
229 | std::move(serializedPyObj), isAsyncExecution); |
230 | |
231 | auto agent = RpcAgent::getCurrentRpcAgent(); |
232 | return toPyJitFuture(sendMessageWithAutograd( |
233 | *agent, |
234 | dst, |
235 | std::move(*pythonCall).toMessage(), |
236 | true /*forceGradRecording*/, |
237 | rpcTimeoutSeconds)); |
238 | } |
239 | |
240 | c10::intrusive_ptr<JitFuture> pyRpcTorchscript( |
241 | const std::string& dstWorkerName, |
242 | const std::string& qualifiedNameStr, |
243 | const py::tuple& argsTuple, |
244 | const py::dict& kwargsDict, |
245 | const float rpcTimeoutSeconds, |
246 | const bool isAsyncExecution) { |
247 | // No need to catch exception here, if function can not be found, |
248 | // exception will be thrown in get_function() call; if args do not match |
249 | // with function schema, exception will be thrown in |
250 | // createStackForSchema() call. |
251 | DCHECK(!PyGILState_Check()); |
252 | const c10::QualifiedName qualifiedName(qualifiedNameStr); |
253 | auto functionSchema = PythonRpcHandler::getInstance() |
254 | .jitCompilationUnit() |
255 | ->get_function(qualifiedName) |
256 | .getSchema(); |
257 | Stack stack; |
258 | { |
259 | // Acquire GIL for py::args and py::kwargs processing. |
260 | py::gil_scoped_acquire acquire; |
261 | stack = torch::jit::createStackForSchema( |
262 | functionSchema, |
263 | argsTuple.cast<py::args>(), |
264 | kwargsDict.cast<py::kwargs>(), |
265 | c10::nullopt); |
266 | } |
267 | DCHECK(!PyGILState_Check()); |
268 | c10::intrusive_ptr<c10::ivalue::Future> fut = rpcTorchscript( |
269 | dstWorkerName, |
270 | qualifiedName, |
271 | functionSchema, |
272 | stack, |
273 | rpcTimeoutSeconds, |
274 | isAsyncExecution); |
275 | return fut; |
276 | } |
277 | |
// remote() entry point for builtin operators. Matches the op under the GIL,
// then (GIL released) either sends a ScriptRemoteCall to a remote worker and
// returns a UserRRef, or — when dst is this worker — creates a local
// OwnerRRef and sends the call to self.
PyRRef pyRemoteBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const float rpcTimeoutSeconds,
    const py::args& args,
    const py::kwargs& kwargs) {
  DCHECK(PyGILState_Check());
  Stack stack;
  auto op = matchBuiltinOp(opName, args, kwargs, stack);
  // Release GIL since args and kwargs processing is done.
  py::gil_scoped_release release;
  // The static type of the op's (single) return becomes the RRef's type.
  TypePtr returnType = op->schema().returns()[0].type();

  auto& ctx = RRefContext::getInstance();
  auto agent = RpcAgent::getCurrentRpcAgent();

  if (ctx.getWorkerId() != dst.id_) {
    // Remote destination: create a UserRRef fork that the owner must confirm.
    auto userRRef = ctx.createUserRRef(dst.id_, returnType);

    auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
        op, std::move(stack), userRRef->rrefId(), userRRef->forkId());

    auto jitFuture = sendMessageWithAutograd(
        *agent,
        dst,
        std::move(*scriptRemoteCall).toMessage(),
        /*forceGradRecord */ false,
        /* timeout */ rpcTimeoutSeconds);

    // Register the send future and the pending user BEFORE attaching the
    // confirmation callback, so the callback observes consistent state.
    userRRef->registerOwnerCreationFuture(jitFuture);
    ctx.addPendingUser(userRRef->forkId(), userRRef);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [forkId{userRRef->forkId()}](JitFuture& future) {
          callback::confirmPendingUser(future, forkId);
        }));
    return PyRRef(userRRef);
  } else {
    // Sending remote call to self: result lives in a local OwnerRRef.
    auto ownerRRef = ctx.createOwnerRRef(returnType);
    // prevent this owner RRef being deleted due to other forks
    ctx.addSelfAsFork(ownerRRef);

    // NOTE: rrefId is deliberately passed for both the rrefId and forkId
    // arguments on the self-remote path.
    auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
        op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId());
    auto jitFuture = sendMessageWithAutograd(
        *agent,
        dst,
        std::move(*scriptRemoteCall).toMessage(),
        /* forceGradRecord */ false,
        /* timeout */ rpcTimeoutSeconds);

    ownerRRef->registerOwnerCreationFuture(jitFuture);
    // Builtin operators does not return py::object, and hence does not require
    // GIL for destructing the potentially deleted OwerRRef.
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& future) {
          callback::finishCreatingOwnerRRef(future, ownerRRefId);
        }));
    return PyRRef(ownerRRef);
  }
}
338 | |
// remote() entry point for pickled Python UDFs. Must be called with the GIL
// released. Sends a PythonRemoteCall to `dst` and returns a UserRRef, or —
// when dst is this worker — creates a local OwnerRRef and sends to self.
// Consumes (moves from) pickledPythonUDF/tensors.
PyRRef pyRemotePythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  DCHECK(!PyGILState_Check());
  auto& ctx = RRefContext::getInstance();
  auto serializedPyObj =
      SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));

  if (ctx.getWorkerId() != dst.id_) {
    // Remote destination: create a UserRRef fork that the owner must confirm.
    auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get());
    auto jitFuture = sendPythonRemoteCall(
        dst,
        std::move(serializedPyObj),
        userRRef->rrefId().toIValue(),
        userRRef->forkId().toIValue(),
        rpcTimeoutSeconds,
        isAsyncExecution);

    // Register the send future and the pending user BEFORE attaching the
    // confirmation callback, so the callback observes consistent state.
    userRRef->registerOwnerCreationFuture(jitFuture);
    ctx.addPendingUser(userRRef->forkId(), userRRef);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [forkId{userRRef->forkId()}](JitFuture& future) {
          callback::confirmPendingUser(future, forkId);
        }));
    return PyRRef(userRRef);
  } else {
    // Sending remote message to self
    auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get());
    // prevent this owner RRef being deleted due to other forks
    ctx.addSelfAsFork(ownerRRef);
    // NOTE: rrefId is deliberately passed for both the rrefId and forkId
    // arguments on the self-remote path.
    auto jitFuture = sendPythonRemoteCall(
        dst,
        std::move(serializedPyObj),
        ownerRRef->rrefId().toIValue(),
        ownerRRef->rrefId().toIValue(),
        rpcTimeoutSeconds,
        isAsyncExecution);

    ownerRRef->registerOwnerCreationFuture(jitFuture);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& future) {
          auto deletedRRef =
              callback::finishCreatingOwnerRRef(future, ownerRRefId);
          // If this released the last reference to a py::object-holding
          // RRef, destroy it under the GIL as required by CPython.
          if (deletedRRef && deletedRRef->isPyObj()) {
            py::gil_scoped_acquire ag;
            deletedRRef.reset();
          }
        }));
    return PyRRef(ownerRRef);
  }
}
393 | |
394 | PyRRef pyRemoteTorchscript( |
395 | const std::string& dstWorkerName, |
396 | const std::string& qualifiedNameStr, |
397 | const float rpcTimeoutSeconds, |
398 | const bool isAsyncExecution, |
399 | const py::args& args, |
400 | const py::kwargs& kwargs) { |
401 | DCHECK(!PyGILState_Check()); |
402 | auto qualifiedName = c10::QualifiedName(qualifiedNameStr); |
403 | auto functionSchema = PythonRpcHandler::getInstance() |
404 | .jitCompilationUnit() |
405 | ->get_function(qualifiedName) |
406 | .getSchema(); |
407 | Stack stack; |
408 | { |
409 | // Acquire GIL for py::args and py::kwargs processing. |
410 | py::gil_scoped_acquire ag; |
411 | stack = torch::jit::createStackForSchema( |
412 | functionSchema, args, kwargs, c10::nullopt); |
413 | } |
414 | DCHECK(!PyGILState_Check()); |
415 | auto rrefPtr = remoteTorchscript( |
416 | dstWorkerName, |
417 | qualifiedName, |
418 | functionSchema, |
419 | stack, |
420 | rpcTimeoutSeconds, |
421 | isAsyncExecution); |
422 | return PyRRef(rrefPtr); |
423 | } |
424 | |
425 | } // namespace rpc |
426 | } // namespace distributed |
427 | } // namespace torch |
428 | |