1#include <ATen/ThreadLocalState.h>
2#include <c10/util/C++17.h>
3#include <torch/csrc/distributed/autograd/context/container.h>
4#include <torch/csrc/distributed/autograd/utils.h>
5#include <torch/csrc/distributed/rpc/message.h>
6#include <torch/csrc/distributed/rpc/python_call.h>
7#include <torch/csrc/distributed/rpc/python_functions.h>
8#include <torch/csrc/distributed/rpc/python_remote_call.h>
9#include <torch/csrc/distributed/rpc/python_resp.h>
10#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
11#include <torch/csrc/distributed/rpc/rref_context.h>
12#include <torch/csrc/distributed/rpc/rref_proto.h>
13#include <torch/csrc/distributed/rpc/script_call.h>
14#include <torch/csrc/distributed/rpc/script_remote_call.h>
15#include <torch/csrc/distributed/rpc/script_resp.h>
16#include <torch/csrc/distributed/rpc/torchscript_functions.h>
17#include <torch/csrc/distributed/rpc/utils.h>
18#include <torch/csrc/jit/runtime/operator.h>
19#include <torch/csrc/utils/python_compat.h>
20#include <exception>
21
22namespace torch {
23namespace distributed {
24namespace rpc {
25
26namespace {
27
// Converts an RPC response Message into an IValue that wraps the Python
// result as a PyObject IValue. Handles SCRIPT_RET (TorchScript result) and
// PYTHON_RET (pickled Python result); any other message type is an error.
// NOTE(review): callers may invoke this without the GIL held — the GIL is
// acquired internally exactly where Python objects are created/destroyed.
IValue toPyIValue(const Message& message) {
  MessageType msgType = message.type();
  auto response = deserializeResponse(message, msgType);
  switch (msgType) {
    case MessageType::SCRIPT_RET: {
      auto& ret = static_cast<ScriptResp&>(*response);
      Stack stack;
      stack.push_back(ret.value());
      // Need GIL to guard createPyObjectForStack() and its returned
      // py::object
      py::gil_scoped_acquire acquire;
      return jit::toIValue(
          torch::jit::createPyObjectForStack(std::move(stack)),
          PyObjectType::get());
    }
    case MessageType::PYTHON_RET: {
      // TODO: Try to avoid a copy here.
      auto& resp = static_cast<PythonResp&>(*response);
      auto& pythonRpcHandler = PythonRpcHandler::getInstance();
      // Need GIL to destruct the py::object returned by deserialize()
      py::gil_scoped_acquire acquire;
      py::object value = pythonRpcHandler.deserialize(resp.serializedPyObj());
      // Re-raises any remote exception that was serialized into the payload.
      pythonRpcHandler.handleException(value);
      return jit::toIValue(value, PyObjectType::get());
    }
    default: {
      TORCH_CHECK(false, "Unrecognized response message type ", msgType);
    }
  }
}
58
59std::shared_ptr<Operator> matchBuiltinOp(
60 const std::string& opName,
61 const py::args& args,
62 const py::kwargs& kwargs,
63 Stack& stack) {
64 Symbol symbol = Symbol::fromQualString(opName);
65
66 std::shared_ptr<jit::Operator> matchedOperator;
67 if (symbol.is_aten()) {
68 // Prefer C10 ops so that they go through C10 dispatch. We expect the
69 // total # of possible overloaded ops (i.e. size of below ops list) to be
70 // small (i.e. it is 10 for torch.add) so a worst-case linear search should
71 // not incur significant extra overhead.
72 auto ops = torch::jit::getAllOperatorsFor(symbol);
73 std::vector<std::shared_ptr<torch::jit::Operator>> c10OpsForSymbol;
74 for (auto it = ops.begin(); it != ops.end();) {
75 std::shared_ptr<jit::Operator> op = *it;
76 if (op->isC10Op()) {
77 c10OpsForSymbol.emplace_back(std::move(op));
78 it = ops.erase(it);
79 } else {
80 ++it;
81 }
82 }
83
84 // Don't throw on failures in this call, since we are not examining on all
85 // operators here, and the matched operator may indeed not be a c10 op.
86 std::pair<std::shared_ptr<torch::jit::Operator>, torch::jit::Stack>
87 opWithStack;
88 try {
89 opWithStack = torch::jit::getOpWithStack(c10OpsForSymbol, args, kwargs);
90 } catch (const std::runtime_error& e) {
91 opWithStack = torch::jit::getOpWithStack(ops, args, kwargs);
92 }
93 matchedOperator = std::get<0>(opWithStack);
94 stack = std::get<1>(opWithStack);
95 }
96
97 // We should never hit this path, since if !matchedOperator, then the last
98 // call to getOpWithStack should have thrown.
99 TORCH_CHECK(
100 matchedOperator != nullptr,
101 "Failed to match operator name ",
102 opName,
103 " and arguments "
104 "(args: ",
105 args,
106 ", kwargs: ",
107 kwargs,
108 ") to a builtin operator");
109
110 return matchedOperator;
111}
112
113c10::intrusive_ptr<JitFuture> sendPythonRemoteCall(
114 const WorkerInfo& dst,
115 SerializedPyObj serializedPyObj,
116 const IValue& rrefId,
117 const IValue& forkId,
118 const float rpcTimeoutSeconds,
119 const bool isAsyncExecution) {
120 auto pythonRemoteCall = std::make_unique<PythonRemoteCall>(
121 std::move(serializedPyObj), rrefId, forkId, isAsyncExecution);
122
123 // set forceGradRecording to true as even if the args does not contain any
124 // tensor, the return value might still contain tensors.
125 auto agent = RpcAgent::getCurrentRpcAgent();
126 return torch::distributed::autograd::sendMessageWithAutograd(
127 *agent,
128 dst,
129 std::move(*pythonRemoteCall).toMessage(),
130 true /*forceGradRecording*/,
131 rpcTimeoutSeconds);
132}
133
134} // namespace
135
136using namespace torch::distributed::autograd;
137
// Converts a future holding an RPC response Message into a future holding a
// Python-friendly value.
//
// When hasValue is true, the returned child future completes with the
// deserialized Python result (PyObjectType) or with an error mirroring the
// parent's failure; when false, the result is discarded and the returned
// future completes with None (errors still propagate via rethrow in then()).
c10::intrusive_ptr<JitFuture> toPyJitFuture(
    const c10::intrusive_ptr<JitFuture>& messageJitFuture,
    bool hasValue) {
  if (hasValue) {
    auto child = messageJitFuture->createInstance(PyObjectType::get());
    messageJitFuture->addCallback(
        // wrapPropagateTLSState captures the caller's thread-local state so
        // the callback observes it when run on another thread.
        at::wrapPropagateTLSState([child](JitFuture& future) {
          if (future.hasError()) {
            child->setError(future.exception_ptr());
          } else {
            const Message& message = *future.value().toCustomClass<Message>();

            // toPyIValue might throw and we need to record the appropriate
            // exception.
            IValue ivalue;
            try {
              ivalue = toPyIValue(message);
            } catch (py::error_already_set& e) {
              // Need the GIL: error_already_set owns Python objects.
              py::gil_scoped_acquire acquire;
              // FIXME: this is a temporary solution to add a special-case for
              // ValueError and TypeError, as those are already used in our
              // tests. We should have a more comprehensive coverage for other
              // types of exceptions as well.
              if (e.matches(PyExc_ValueError)) {
                child->setErrorIfNeeded(
                    std::make_exception_ptr(pybind11::value_error(e.what())));
              } else if (e.matches(PyExc_TypeError)) {
                child->setErrorIfNeeded(
                    std::make_exception_ptr(pybind11::type_error(e.what())));
              } else {
                // py::error_already_set requires GIL to destruct, take special
                // care.
                child->setErrorIfNeeded(
                    std::make_exception_ptr(std::runtime_error(e.what())));
              }
              // Restore then clear releases the Python error state so it does
              // not leak into subsequent Python calls on this thread.
              e.restore();
              PyErr_Clear();
              return;
            } catch (std::exception& e) {
              child->setErrorIfNeeded(std::current_exception());
              return;
            }

            // Propagate the parent's storages so DataPtr lifetime tracking
            // carries over to the child future.
            child->markCompleted(ivalue, future.storages());
          }
        }));
    return child;
  } else {
    return messageJitFuture->then(
        at::wrapPropagateTLSState([](JitFuture& future) {
          if (future.hasError()) {
            std::rethrow_exception(future.exception_ptr());
          } else {
            return IValue();
          }
        }),
        NoneType::get());
  }
}
197
198c10::intrusive_ptr<JitFuture> pyRpcBuiltin(
199 const WorkerInfo& dst,
200 const std::string& opName,
201 const py::args& args,
202 const py::kwargs& kwargs,
203 const float rpcTimeoutSeconds) {
204 DCHECK(PyGILState_Check());
205 Stack stack;
206 auto op = matchBuiltinOp(opName, args, kwargs, stack);
207 // Release GIL since args and kwargs processing is done.
208 py::gil_scoped_release release;
209 auto scriptCall = std::make_unique<ScriptCall>(op, std::move(stack));
210 auto agent = RpcAgent::getCurrentRpcAgent();
211 return toPyJitFuture(sendMessageWithAutograd(
212 *agent,
213 dst,
214 std::move(*scriptCall).toMessage(),
215 false,
216 rpcTimeoutSeconds));
217}
218
219c10::intrusive_ptr<JitFuture> pyRpcPythonUdf(
220 const WorkerInfo& dst,
221 std::string& pickledPythonUDF,
222 std::vector<torch::Tensor>& tensors,
223 const float rpcTimeoutSeconds,
224 const bool isAsyncExecution) {
225 DCHECK(!PyGILState_Check());
226 auto serializedPyObj =
227 SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));
228 auto pythonCall = std::make_unique<PythonCall>(
229 std::move(serializedPyObj), isAsyncExecution);
230
231 auto agent = RpcAgent::getCurrentRpcAgent();
232 return toPyJitFuture(sendMessageWithAutograd(
233 *agent,
234 dst,
235 std::move(*pythonCall).toMessage(),
236 true /*forceGradRecording*/,
237 rpcTimeoutSeconds));
238}
239
240c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
241 const std::string& dstWorkerName,
242 const std::string& qualifiedNameStr,
243 const py::tuple& argsTuple,
244 const py::dict& kwargsDict,
245 const float rpcTimeoutSeconds,
246 const bool isAsyncExecution) {
247 // No need to catch exception here, if function can not be found,
248 // exception will be thrown in get_function() call; if args do not match
249 // with function schema, exception will be thrown in
250 // createStackForSchema() call.
251 DCHECK(!PyGILState_Check());
252 const c10::QualifiedName qualifiedName(qualifiedNameStr);
253 auto functionSchema = PythonRpcHandler::getInstance()
254 .jitCompilationUnit()
255 ->get_function(qualifiedName)
256 .getSchema();
257 Stack stack;
258 {
259 // Acquire GIL for py::args and py::kwargs processing.
260 py::gil_scoped_acquire acquire;
261 stack = torch::jit::createStackForSchema(
262 functionSchema,
263 argsTuple.cast<py::args>(),
264 kwargsDict.cast<py::kwargs>(),
265 c10::nullopt);
266 }
267 DCHECK(!PyGILState_Check());
268 c10::intrusive_ptr<c10::ivalue::Future> fut = rpcTorchscript(
269 dstWorkerName,
270 qualifiedName,
271 functionSchema,
272 stack,
273 rpcTimeoutSeconds,
274 isAsyncExecution);
275 return fut;
276}
277
// Creates a remote reference (RRef) to the result of running a builtin
// operator on `dst`. Returns a UserRRef when dst is a different worker, or
// an OwnerRRef when dst is this worker. Must be entered with the GIL held
// (args/kwargs are live Python objects).
PyRRef pyRemoteBuiltin(
    const WorkerInfo& dst,
    const std::string& opName,
    const float rpcTimeoutSeconds,
    const py::args& args,
    const py::kwargs& kwargs) {
  DCHECK(PyGILState_Check());
  Stack stack;
  auto op = matchBuiltinOp(opName, args, kwargs, stack);
  // Release GIL since args and kwargs processing is done.
  py::gil_scoped_release release;
  // Builtin ops have exactly the schema's first return as the RRef type.
  // NOTE(review): assumes the matched op has at least one return — verify
  // against builtin op schemas.
  TypePtr returnType = op->schema().returns()[0].type();

  auto& ctx = RRefContext::getInstance();
  auto agent = RpcAgent::getCurrentRpcAgent();

  if (ctx.getWorkerId() != dst.id_) {
    // Remote case: create a UserRRef fork and ask dst to create the owner.
    auto userRRef = ctx.createUserRRef(dst.id_, returnType);

    auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
        op, std::move(stack), userRRef->rrefId(), userRRef->forkId());

    auto jitFuture = sendMessageWithAutograd(
        *agent,
        dst,
        std::move(*scriptRemoteCall).toMessage(),
        /*forceGradRecord */ false,
        /* timeout */ rpcTimeoutSeconds);

    // Register the send future before adding the confirmation callback so
    // the RRef can track owner-creation completion/errors.
    userRRef->registerOwnerCreationFuture(jitFuture);
    ctx.addPendingUser(userRRef->forkId(), userRRef);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [forkId{userRRef->forkId()}](JitFuture& future) {
          callback::confirmPendingUser(future, forkId);
        }));
    return PyRRef(userRRef);
  } else {
    // Self-remote case: this worker owns the result directly.
    auto ownerRRef = ctx.createOwnerRRef(returnType);
    // prevent this owner RRef being deleted due to other forks
    ctx.addSelfAsFork(ownerRRef);

    // rrefId is passed for both rrefId and forkId since owner == caller.
    auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
        op, std::move(stack), ownerRRef->rrefId(), ownerRRef->rrefId());
    auto jitFuture = sendMessageWithAutograd(
        *agent,
        dst,
        std::move(*scriptRemoteCall).toMessage(),
        /* forceGradRecord */ false,
        /* timeout */ rpcTimeoutSeconds);

    ownerRRef->registerOwnerCreationFuture(jitFuture);
    // Builtin operators does not return py::object, and hence does not require
    // GIL for destructing the potentially deleted OwerRRef.
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& future) {
          callback::finishCreatingOwnerRRef(future, ownerRRefId);
        }));
    return PyRRef(ownerRRef);
  }
}
338
// Creates a remote reference (RRef) to the result of running a pickled
// Python UDF on `dst`. Returns a UserRRef for a different worker, or an
// OwnerRRef when dst is this worker. Must be entered WITHOUT the GIL;
// consumes pickledPythonUDF and tensors.
PyRRef pyRemotePythonUdf(
    const WorkerInfo& dst,
    std::string& pickledPythonUDF,
    std::vector<torch::Tensor>& tensors,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  DCHECK(!PyGILState_Check());
  auto& ctx = RRefContext::getInstance();
  auto serializedPyObj =
      SerializedPyObj(std::move(pickledPythonUDF), std::move(tensors));

  if (ctx.getWorkerId() != dst.id_) {
    // Remote case: fork a UserRRef of PyObject type and ask dst to create
    // the owner.
    auto userRRef = ctx.createUserRRef(dst.id_, PyObjectType::get());
    auto jitFuture = sendPythonRemoteCall(
        dst,
        std::move(serializedPyObj),
        userRRef->rrefId().toIValue(),
        userRRef->forkId().toIValue(),
        rpcTimeoutSeconds,
        isAsyncExecution);

    // Register the send future before adding the confirmation callback so
    // the RRef can track owner-creation completion/errors.
    userRRef->registerOwnerCreationFuture(jitFuture);
    ctx.addPendingUser(userRRef->forkId(), userRRef);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [forkId{userRRef->forkId()}](JitFuture& future) {
          callback::confirmPendingUser(future, forkId);
        }));
    return PyRRef(userRRef);
  } else {
    // Sending remote message to self
    auto ownerRRef = ctx.createOwnerRRef(PyObjectType::get());
    // prevent this owner RRef being deleted due to other forks
    ctx.addSelfAsFork(ownerRRef);
    // rrefId is passed for both rrefId and forkId since owner == caller.
    auto jitFuture = sendPythonRemoteCall(
        dst,
        std::move(serializedPyObj),
        ownerRRef->rrefId().toIValue(),
        ownerRRef->rrefId().toIValue(),
        rpcTimeoutSeconds,
        isAsyncExecution);

    ownerRRef->registerOwnerCreationFuture(jitFuture);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [ownerRRefId = ownerRRef->rrefId()](JitFuture& future) {
          auto deletedRRef =
              callback::finishCreatingOwnerRRef(future, ownerRRefId);
          if (deletedRRef && deletedRRef->isPyObj()) {
            // The deleted RRef holds a py::object; its destruction must run
            // under the GIL.
            py::gil_scoped_acquire ag;
            deletedRRef.reset();
          }
        }));
    return PyRRef(ownerRRef);
  }
}
393
394PyRRef pyRemoteTorchscript(
395 const std::string& dstWorkerName,
396 const std::string& qualifiedNameStr,
397 const float rpcTimeoutSeconds,
398 const bool isAsyncExecution,
399 const py::args& args,
400 const py::kwargs& kwargs) {
401 DCHECK(!PyGILState_Check());
402 auto qualifiedName = c10::QualifiedName(qualifiedNameStr);
403 auto functionSchema = PythonRpcHandler::getInstance()
404 .jitCompilationUnit()
405 ->get_function(qualifiedName)
406 .getSchema();
407 Stack stack;
408 {
409 // Acquire GIL for py::args and py::kwargs processing.
410 py::gil_scoped_acquire ag;
411 stack = torch::jit::createStackForSchema(
412 functionSchema, args, kwargs, c10::nullopt);
413 }
414 DCHECK(!PyGILState_Check());
415 auto rrefPtr = remoteTorchscript(
416 dstWorkerName,
417 qualifiedName,
418 functionSchema,
419 stack,
420 rpcTimeoutSeconds,
421 isAsyncExecution);
422 return PyRRef(rrefPtr);
423}
424
425} // namespace rpc
426} // namespace distributed
427} // namespace torch
428