1#include <torch/csrc/python_headers.h>
2
3#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
4#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
5#include <torch/csrc/distributed/rpc/py_rref.h>
6#include <torch/csrc/distributed/rpc/python_functions.h>
7#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
8#include <torch/csrc/distributed/rpc/request_callback_impl.h>
9#include <torch/csrc/distributed/rpc/rpc_agent.h>
10#include <torch/csrc/distributed/rpc/rref_context.h>
11#include <torch/csrc/distributed/rpc/tensorpipe_agent.h>
12#include <torch/csrc/distributed/rpc/torchscript_functions.h>
13#include <torch/csrc/distributed/rpc/types.h>
14#include <torch/csrc/jit/python/pybind_utils.h>
15#include <torch/csrc/utils/object_ptr.h>
16#include <torch/csrc/utils/pybind.h>
17#include <torch/types.h>
18
19#include <pybind11/chrono.h>
20#include <pybind11/operators.h>
21
22namespace torch {
23namespace distributed {
24namespace rpc {
25
26namespace {
27
// Default timeout used by _delete_all_user_and_unforked_owner_rrefs during
// graceful RPC shutdown (100,000 ms == 100 s).
constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000);

// Convenience alias: bind T to Python with std::shared_ptr holder semantics,
// which all RPC types below use so ownership can be shared with C++.
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
32
33PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
34 auto rpc_module =
35 THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
36 if (!rpc_module) {
37 throw python_error();
38 }
39
40 auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
41 if (!torch_C_module) {
42 throw python_error();
43 }
44
45 auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
46 auto m =
47 torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");
48
49 auto module = py::handle(m).cast<py::module>();
50
51 auto rpcBackendOptions =
52 shared_ptr_class_<RpcBackendOptions>(
53 module,
54 "RpcBackendOptions",
55 R"(An abstract structure encapsulating the options passed into the RPC
56 backend. An instance of this class can be passed in to
57 :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
58 with specific configurations, such as the RPC timeout and
59 ``init_method`` to be used. )")
60 .def(py::init<>())
61 .def(
62 py::init<float, std::string>(),
63 py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
64 py::arg("init_method") = kDefaultInitMethod)
65 .def_readwrite(
66 "rpc_timeout",
67 &RpcBackendOptions::rpcTimeoutSeconds,
68 R"(A float indicating the timeout to use for all
69 RPCs. If an RPC does not complete in this timeframe, it will
70 complete with an exception indicating that it has timed out.)")
71 .def_readwrite(
72 "init_method",
73 &RpcBackendOptions::initMethod,
74 R"(URL specifying how to initialize the process group.
75 Default is ``env://``)");
76
77 // The following C++ constants need to be cast so they can be used from
78 // python.
79 module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds);
80 module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout);
81 module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);
82
83 auto workerInfo =
84 shared_ptr_class_<WorkerInfo>(
85 module,
86 "WorkerInfo",
87 R"(A structure that encapsulates information of a worker in the system.
88 Contains the name and ID of the worker. This class is not meant to
89 be constructed directly, rather, an instance can be retrieved
90 through :meth:`~torch.distributed.rpc.get_worker_info` and the
91 result can be passed in to functions such as
92 :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`,
93 :meth:`~torch.distributed.rpc.remote` to avoid copying a string on
94 every invocation.)")
95 .def(
96 py::init<std::string, worker_id_t>(),
97 py::arg("name"),
98 py::arg("id"))
99 .def_readonly(
100 "name", &WorkerInfo::name_, R"(The name of the worker.)")
101 .def_readonly(
102 "id",
103 &WorkerInfo::id_,
104 R"(Globally unique id to identify the worker.)")
105 .def("__eq__", &WorkerInfo::operator==, py::is_operator())
106 // pybind11 suggests the syntax .def(hash(py::self)), with the
107 // unqualified "hash" function call. However the
108 // argument-dependent lookup for the function "hash" doesn't get
109 // triggered in this context because it conflicts with the struct
110 // c10::hash, so we need to use the qualified name
111 // py::detail::hash, which unfortunately is in a detail namespace.
112 .def(py::detail::hash(py::self)) // NOLINT
113 .def(
114 "__repr__",
115 [](const WorkerInfo& workerInfo) {
116 std::ostringstream os;
117 os << workerInfo;
118 return os.str();
119 })
120 .def(py::pickle(
121 /* __getstate__ */
122 [](const WorkerInfo& workerInfo) {
123 return py::make_tuple(workerInfo.name_, workerInfo.id_);
124 },
125 /* __setstate__ */
126 [](py::tuple t) {
127 TORCH_CHECK(t.size() == 2, "Invalid WorkerInfo state.");
128
129 WorkerInfo info(
130 t[0].cast<std::string>(), t[1].cast<worker_id_t>());
131 return info;
132 }));
133
134 auto rpcAgent =
135 shared_ptr_class_<RpcAgent>(module, "RpcAgent")
136 .def(
137 "join",
138 &RpcAgent::join,
139 py::call_guard<py::gil_scoped_release>(),
140 py::arg("shutdown") = false,
141 py::arg("timeout") = 0)
142 .def(
143 "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
144 .def(
145 "shutdown",
146 &RpcAgent::shutdown,
147 py::call_guard<py::gil_scoped_release>())
148 .def(
149 "get_worker_info",
150 (const WorkerInfo& (RpcAgent::*)(void) const) &
151 RpcAgent::getWorkerInfo,
152 py::call_guard<py::gil_scoped_release>())
153 .def(
154 "get_worker_info",
155 (const WorkerInfo& (RpcAgent::*)(const std::string&) const) &
156 RpcAgent::getWorkerInfo,
157 py::call_guard<py::gil_scoped_release>())
158 .def(
159 "get_worker_infos",
160 &RpcAgent::getWorkerInfos,
161 py::call_guard<py::gil_scoped_release>())
162 .def(
163 "_get_device_map",
164 &RpcAgent::getDeviceMap,
165 py::call_guard<py::gil_scoped_release>())
166 .def(
167 "get_debug_info",
168 &RpcAgent::getDebugInfo,
169 py::call_guard<py::gil_scoped_release>())
170 .def(
171 "get_metrics",
172 &RpcAgent::getMetrics,
173 py::call_guard<py::gil_scoped_release>());
174
175 auto pyRRef =
176 shared_ptr_class_<PyRRef>(module, "PyRRef", R"(
177 A class encapsulating a reference to a value of some type on a remote
178 worker. This handle will keep the referenced remote value alive on the
179 worker. A ``UserRRef`` will be deleted when 1) no references to it in
180 both the application code and in the local RRef context, or 2) the
181 application has called a graceful shutdown. Invoking methods on a
182 deleted RRef leads to undefined behaviors. RRef implementation only
183 offers best-effort error detection, and applications should not use
184 ``UserRRefs`` after ``rpc.shutdown()``.
185
186 .. warning::
187 RRefs can only be serialized and deserialized by the RPC module.
188 Serializing and deserializing RRefs without RPC (e.g., Python
189 pickle, torch :meth:`~torch.save` / :meth:`~torch.load`,
190 JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will
191 lead to errors.
192
193 Args:
194 value (object): The value to be wrapped by this RRef.
195 type_hint (Type, optional): Python type that should be passed to
196 ``TorchScript`` compiler as type hint for ``value``.
197
198 Example::
199 Following examples skip RPC initialization and shutdown code
200 for simplicity. Refer to RPC docs for those details.
201
202 1. Create an RRef using rpc.remote
203
204 >>> import torch
205 >>> import torch.distributed.rpc as rpc
206 >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
207 >>> # get a copy of value from the RRef
208 >>> x = rref.to_here()
209
210 2. Create an RRef from a local object
211
212 >>> import torch
213 >>> from torch.distributed.rpc import RRef
214 >>> x = torch.zeros(2, 2)
215 >>> rref = RRef(x)
216
217 3. Share an RRef with other workers
218
219 >>> # On both worker0 and worker1:
220 >>> def f(rref):
221 >>> return rref.to_here() + 1
222
223 >>> # On worker0:
224 >>> import torch
225 >>> import torch.distributed.rpc as rpc
226 >>> from torch.distributed.rpc import RRef
227 >>> rref = RRef(torch.zeros(2, 2))
228 >>> # the following RPC shares the rref with worker1, reference
229 >>> # count is automatically updated.
230 >>> rpc.rpc_sync("worker1", f, args=(rref,))
231 )")
232 .def(
233 py::init<const py::object&, const py::object&>(),
234 py::arg("value"),
235 py::arg("type_hint") = py::none())
236 .def(
237 // not releasing GIL here to avoid context switch on getters
238 "is_owner",
239 &PyRRef::isOwner,
240 R"(
241 Returns whether or not the current node is the owner of this
242 ``RRef``.
243 )")
244 .def(
245 "confirmed_by_owner",
246 &PyRRef::confirmedByOwner,
247 R"(
248 Returns whether this ``RRef`` has been confirmed by the owner.
249 ``OwnerRRef`` always returns true, while ``UserRRef`` only
250 returns true when the owner knowns about this ``UserRRef``.
251 )")
252 .def(
253 // not releasing GIL here to avoid context switch on getters
254 "owner",
255 &PyRRef::owner,
256 R"(
257 Returns worker information of the node that owns this ``RRef``.
258 )")
259 .def(
260 // not releasing GIL here to avoid context switch on getters
261 "owner_name",
262 &PyRRef::ownerName,
263 R"(
264 Returns worker name of the node that owns this ``RRef``.
265 )")
266 .def(
267 "to_here",
268 &PyRRef::toHere,
269 py::arg("timeout") = py::cast(kUnsetRpcTimeout),
270 py::call_guard<py::gil_scoped_release>(),
271 R"(
272 Blocking call that copies the value of the RRef from the owner
273 to the local node and returns it. If the current node is the
274 owner, returns a reference to the local value.
275
276 Args:
277 timeout (float, optional): Timeout for ``to_here``. If
278 the call does not complete within this timeframe, an
279 exception indicating so will be raised. If this
280 argument is not provided, the default RPC timeout
281 (60s) will be used.
282 )")
283 .def(
284 "local_value",
285 &PyRRef::localValue,
286 py::call_guard<py::gil_scoped_release>(),
287 R"(
288 If the current node is the owner, returns a reference to the
289 local value. Otherwise, throws an exception.
290 )")
291 .def(
292 "rpc_sync",
293 [](const PyRRef& self, float timeoutSeconds) {
294 return self.createRRefProxy(
295 RRefProxyType::RPC_SYNC, timeoutSeconds);
296 },
297 py::arg("timeout") = kUnsetRpcTimeout,
298 py::call_guard<py::gil_scoped_release>(),
299 R"(
300 Create a helper proxy to easily launch an ``rpc_sync`` using
301 the owner of the RRef as the destination to run functions on
302 the object referenced by this RRef. More specifically,
303 ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as
304 the following:
305
306 >>> def run(rref, func_name, args, kwargs):
307 >>> return getattr(rref.local_value(), func_name)(*args, **kwargs)
308 >>>
309 >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))
310
311 Args:
312 timeout (float, optional): Timeout for ``rref.rpc_sync()``.
313 If the call does not complete within this timeframe, an
314 exception indicating so will be raised. If this argument
315 is not provided, the default RPC timeout will be used.
316
317 Example::
318 >>> from torch.distributed import rpc
319 >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
320 >>> rref.rpc_sync().size() # returns torch.Size([2, 2])
321 >>> rref.rpc_sync().view(1, 4) # returns tensor([[1., 1., 1., 1.]])
322 )")
323 .def(
324 "rpc_async",
325 [](const PyRRef& self, float timeoutSeconds) {
326 return self.createRRefProxy(
327 RRefProxyType::RPC_ASYNC, timeoutSeconds);
328 },
329 py::arg("timeout") = kUnsetRpcTimeout,
330 py::call_guard<py::gil_scoped_release>(),
331 R"(
332 Create a helper proxy to easily launch an ``rpc_async`` using
333 the owner of the RRef as the destination to run functions on
334 the object referenced by this RRef. More specifically,
335 ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as
336 the following:
337
338 >>> def run(rref, func_name, args, kwargs):
339 >>> return getattr(rref.local_value(), func_name)(*args, **kwargs)
340 >>>
341 >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))
342
343 Args:
344 timeout (float, optional): Timeout for ``rref.rpc_async()``.
345 If the call does not complete within this timeframe, an
346 exception indicating so will be raised. If this argument
347 is not provided, the default RPC timeout will be used.
348
349 Example::
350 >>> from torch.distributed import rpc
351 >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
352 >>> rref.rpc_async().size().wait() # returns torch.Size([2, 2])
353 >>> rref.rpc_async().view(1, 4).wait() # returns tensor([[1., 1., 1., 1.]])
354 )")
355 .def(
356 "remote",
357 [](const PyRRef& self, float timeoutSeconds) {
358 return self.createRRefProxy(
359 RRefProxyType::REMOTE, timeoutSeconds);
360 },
361 py::arg("timeout") = kUnsetRpcTimeout,
362 py::call_guard<py::gil_scoped_release>(),
363 R"(
364 Create a helper proxy to easily launch a ``remote`` using
365 the owner of the RRef as the destination to run functions on
366 the object referenced by this RRef. More specifically,
367 ``rref.remote().func_name(*args, **kwargs)`` is the same as
368 the following:
369
370 >>> def run(rref, func_name, args, kwargs):
371 >>> return getattr(rref.local_value(), func_name)(*args, **kwargs)
372 >>>
373 >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))
374
375 Args:
376 timeout (float, optional): Timeout for ``rref.remote()``. If
377 the creation of this :class:`~torch.distributed.rpc.RRef`
378 is not successfully completed within the timeout, then the
379 next time there is an attempt to use the RRef
380 (such as ``to_here``), a timeout will be raised. If not
381 provided, the default RPC timeout will be used. Please see
382 ``rpc.remote()`` for specific timeout semantics for
383 :class:`~torch.distributed.rpc.RRef`.
384
385 Example::
386 >>> from torch.distributed import rpc
387 >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
388 >>> rref.remote().size().to_here() # returns torch.Size([2, 2])
389 >>> rref.remote().view(1, 4).to_here() # returns tensor([[1., 1., 1., 1.]])
390 )")
391 .def(
392 py::pickle(
393 /* __getstate__ */
394 [](const PyRRef& /* unused */) {
395 TORCH_CHECK(
396 false,
397 "Can not pickle rref in python pickler, rref can only be "
398 "pickled when using RPC");
399 // Note that this return has no meaning since we always
400 // throw, it's only here to satisfy Pybind API's
401 // requirement.
402 return py::make_tuple();
403 },
404 /* __setstate__ */
405 [](py::tuple /* unused */) { // NOLINT
406 TORCH_CHECK(
407 false,
408 "Can not unpickle rref in python pickler, rref can only be "
409 "unpickled when using RPC");
410 // Note that this return has no meaning since we always
411 // throw, it's only here to satisfy PyBind's API
412 // requirement.
413 return PyRRef(
414 py::cast<py::none>(Py_None),
415 py::cast<py::none>(Py_None));
416 }),
417 py::call_guard<py::gil_scoped_release>())
418 .def(
419 "_serialize",
420 &PyRRef::pickle,
421 py::call_guard<py::gil_scoped_release>())
422 .def_static(
423 "_deserialize",
424 &PyRRef::unpickle,
425 py::call_guard<py::gil_scoped_release>())
426 .def(
427 "_get_type",
428 // Intentionally not releasing GIL, as most accesses just
429 // retrieve cached type py::object
430 &PyRRef::getRRefType,
431 py::arg("timeout") = kUnsetRpcTimeout,
432 py::arg("blocking") = true,
433 R"(
434 If ``blocking=True``, returns the type of the data object
435 referenced by this ``RRef``. On the owner, this is same as
436 ``type(rref.local_value())``. Otherwise, returns a future to
437 this result. On a user, this will trigger an RPC to fetch the
438 ``type`` object from the owner. After this function is run
439 once, the ``type`` object is cached by the ``RRef``, and
440 subsequent invocations no longer trigger RPC. Note that this is
441 true regardless of the ``blocking`` argument of subsequent
442 calls.
443
444 Args:
445 rref (torch.distributed.rpc.RRef): The RRef to get type of.
446 timeout (float, optional): Timeout, in seconds for
447 ``_get_type``. If the call does not complete within
448 this timeframe, an exception indicating so will be
449 raised. If this argument is not provided, the default
450 RPC timeout will be used.
451 blocking (bool, optional): Whether to synchronously wait on
452 the RPC triggered by the first call and return the
453 type. If ``False``, will return a future. Default is
454 ``True``.
455 )")
456 .def(
457 "_get_future",
458 [](const PyRRef& self) {
459 return std::make_shared<jit::PythonFutureWrapper>(
460 self.getFuture());
461 },
462 py::call_guard<py::gil_scoped_release>(),
463 R"(
464 Returns the future that corresponds to the creation of this RRef
465 on the remote node. This is for internal use cases such as profiling
466 only.
467 )")
468 .def(
469 "_get_profiling_future",
470 [](const PyRRef& self) {
471 return std::make_shared<jit::PythonFutureWrapper>(
472 self.getProfilingFuture());
473 },
474 py::call_guard<py::gil_scoped_acquire>(),
475 R"(
476 Returns future that completes when the profiling event corresponding
477 to the creation of this RRef on the remote node has been recorded.
478 )")
479 .def(
480 "_set_profiling_future",
481 [](PyRRef& self,
482 const std::shared_ptr<jit::PythonFutureWrapper>&
483 wrappedFuture) {
484 self.setProfilingFuture(wrappedFuture->fut);
485 },
486 py::call_guard<py::gil_scoped_acquire>(),
487 R"(
488 Set future that is completed when the profiling event corresponding
489 to the creation of this RRef on the remote node has been recorded.
490 )")
491 .def(
492 "backward",
493 [](PyRRef& self,
494 int64_t dist_autograd_ctx_id,
495 bool retain_graph) {
496 self.backward(dist_autograd_ctx_id, retain_graph);
497 },
498 py::arg("dist_autograd_ctx_id") = -1,
499 py::arg("retain_graph") = false,
500 py::call_guard<py::gil_scoped_release>(),
501 R"(
502 Runs the backward pass using the RRef as the root of the
503 backward pass. If ``dist_autograd_ctx_id`` is provided,
504 we perform a distributed backward pass using the provided
505 ctx_id starting from the owner of the RRef. In this case,
506 :meth:`~torch.distributed.autograd.get_gradients` should be
507 used to retrieve the gradients. If ``dist_autograd_ctx_id``
508 is ``None``, it is assumed that this is a local autograd graph
509 and we only perform a local backward pass. In the local case,
510 the node calling this API has to be the owner of the RRef.
511 The value of the RRef is expected to be a scalar Tensor.
512
513 Args:
514 dist_autograd_ctx_id (int, optional): The distributed
515 autograd context id for which we should retrieve the
516 gradients (default: -1).
517 retain_graph(bool, optional): If ``False``, the graph used to
518 compute the grad will be freed. Note that in nearly all
519 cases setting this option to ``True`` is not needed and
520 often can be worked around in a much more efficient way.
521 Usually, you need to set this to ``True`` to run backward
522 multiple times (default: False).
523
524 Example::
525 >>> import torch.distributed.autograd as dist_autograd
526 >>> with dist_autograd.context() as context_id:
527 >>> rref.backward(context_id)
528 )")
529 // not releasing GIL to avoid context switch
530 .def("__repr__", &PyRRef::str);
531
532#ifdef USE_TENSORPIPE
533
534 // Base class: torch.distributed.rpc.RpcBackendOptions.
535 py::class_<TensorPipeRpcBackendOptions>(
536 module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions)
537 .def(
538 py::init<
539 int,
540 optional<std::vector<std::string>>,
541 optional<std::vector<std::string>>,
542 float,
543 std::string,
544 std::unordered_map<std::string, DeviceMap>,
545 std::vector<c10::Device>>(),
546 py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
547 py::arg("_transports") = optional<std::vector<std::string>>(),
548 py::arg("_channels") = optional<std::vector<std::string>>(),
549 py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
550 py::arg("init_method") = kDefaultInitMethod,
551 py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(),
552 py::arg("devices") = std::vector<c10::Device>())
553 .def_readwrite(
554 "num_worker_threads",
555 &TensorPipeRpcBackendOptions::numWorkerThreads,
556 R"(
557 The number of threads in the thread-pool used by
558 :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
559 requests.
560 )")
561 .def_readwrite(
562 "device_maps",
563 &TensorPipeRpcBackendOptions::deviceMaps,
564 R"(The device map locations.)")
565 .def_readwrite(
566 "devices",
567 &TensorPipeRpcBackendOptions::devices,
568 R"(All devices used by the local agent.)")
569 .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap);
570
571 module.attr("_DEFAULT_NUM_WORKER_THREADS") =
572 py::cast(kDefaultNumWorkerThreads);
573
574 shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
575 .def(
576 py::init(
577 [](const c10::intrusive_ptr<::c10d::Store>& store,
578 std::string selfName,
579 worker_id_t selfId,
580 optional<int> worldSize,
581 TensorPipeRpcBackendOptions opts,
582 std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
583 std::vector<c10::Device> devices) {
584 return std::shared_ptr<TensorPipeAgent>(
585 new TensorPipeAgent(
586 store,
587 std::move(selfName),
588 selfId,
589 worldSize,
590 std::move(opts),
591 std::move(reverseDeviceMaps),
592 std::move(devices),
593 std::make_unique<RequestCallbackImpl>()),
594 impl::destroy_without_gil<TensorPipeAgent>);
595 }),
596 py::arg("store"),
597 py::arg("name"),
598 py::arg("rank"),
599 py::arg("world_size"),
600 py::arg("rpc_backend_options"),
601 py::arg("reverse_device_maps"),
602 py::arg("devices"))
603 .def(
604 "join",
605 &TensorPipeAgent::join,
606 py::call_guard<py::gil_scoped_release>(),
607 py::arg("shutdown") = false,
608 py::arg("timeout") = 0)
609 .def(
610 "shutdown",
611 &TensorPipeAgent::shutdown,
612 py::call_guard<py::gil_scoped_release>())
613 .def(
614 "get_worker_info",
615 (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
616 RpcAgent::getWorkerInfo,
617 py::call_guard<py::gil_scoped_release>())
618 .def(
619 "get_worker_info",
620 (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
621 TensorPipeAgent::getWorkerInfo,
622 py::call_guard<py::gil_scoped_release>())
623 .def(
624 "get_worker_info",
625 (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
626 TensorPipeAgent::getWorkerInfo,
627 py::call_guard<py::gil_scoped_release>())
628 .def(
629 "get_worker_infos",
630 (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
631 TensorPipeAgent::getWorkerInfos,
632 py::call_guard<py::gil_scoped_release>())
633 .def(
634 "_get_device_map",
635 (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) &
636 TensorPipeAgent::getDeviceMap,
637 py::call_guard<py::gil_scoped_release>())
638 .def(
639 "_get_backend_options",
640 &TensorPipeAgent::getBackendOptions,
641 py::call_guard<py::gil_scoped_release>())
642 .def(
643 "_update_group_membership",
644 &TensorPipeAgent::updateGroupMembership,
645 py::call_guard<py::gil_scoped_release>())
646 .def_readonly("is_static_group", &TensorPipeAgent::isStaticGroup_)
647 .def_property_readonly("store", &TensorPipeAgent::getStore);
648
649#endif // USE_TENSORPIPE
650
651 module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet);
652
653 module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent);
654
655 module.def(
656 "_set_and_start_rpc_agent",
657 [](const std::shared_ptr<RpcAgent>& rpcAgent) {
658 RpcAgent::setCurrentRpcAgent(rpcAgent);
659 // Initializing typeResolver inside RpcAgent constructor will make
660 // RpcAgent have python dependency. To avoid RpcAgent to have python
661 // dependency, setTypeResolver() here.
662 std::shared_ptr<TypeResolver> typeResolver =
663 std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
664 auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
665 qn.qualifiedName());
666 return c10::StrongTypePtr(
667 PythonRpcHandler::getInstance().jitCompilationUnit(),
668 std::move(typePtr));
669 });
670 rpcAgent->setTypeResolver(typeResolver);
671 rpcAgent->start();
672 },
673 py::call_guard<py::gil_scoped_release>());
674
675 module.def(
676 "_reset_current_rpc_agent",
677 []() { RpcAgent::setCurrentRpcAgent(nullptr); },
678 py::call_guard<py::gil_scoped_release>());
679
680 module.def(
681 "_delete_all_user_and_unforked_owner_rrefs",
682 [](std::chrono::milliseconds timeoutMillis) {
683 RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis);
684 },
685 py::arg("timeout") = kDeleteAllUsersTimeout,
686 py::call_guard<py::gil_scoped_release>());
687
688 module.def("_destroy_rref_context", [](bool ignoreRRefLeak) {
689 // NB: do not release GIL in the function. The destroyInstance() method
690 // returns a list of deleted OwnerRRefs that hold py::object instances.
691 // Clearing those OwnerRRefs are likely to trigger Python deref, which
692 // requires GIL.
693 RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear();
694 });
695
696 module.def("_rref_context_get_debug_info", []() {
697 return RRefContext::getInstance().getDebugInfo();
698 });
699
700 module.def(
701 "_cleanup_python_rpc_handler",
702 []() { PythonRpcHandler::getInstance().cleanup(); },
703 py::call_guard<py::gil_scoped_release>());
704
705 module.def(
706 "_invoke_rpc_builtin",
707 [](const WorkerInfo& dst,
708 const std::string& opName,
709 const float rpcTimeoutSeconds,
710 const py::args& args,
711 const py::kwargs& kwargs) {
712 return std::make_shared<jit::PythonFutureWrapper>(
713 pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds));
714 },
715 py::call_guard<py::gil_scoped_acquire>());
716
717 module.def(
718 "_invoke_rpc_python_udf",
719 [](const WorkerInfo& dst,
720 std::string& pickledPythonUDF,
721 std::vector<torch::Tensor>& tensors,
722 const float rpcTimeoutSeconds,
723 const bool isAsyncExecution) {
724 return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
725 dst,
726 pickledPythonUDF,
727 tensors,
728 rpcTimeoutSeconds,
729 isAsyncExecution));
730 },
731 py::call_guard<py::gil_scoped_release>());
732
733 module.def(
734 "_invoke_rpc_torchscript",
735 [](const std::string& dstWorkerName,
736 const std::string& qualifiedNameStr,
737 const py::tuple& argsTuple,
738 const py::dict& kwargsDict,
739 const float rpcTimeoutSeconds,
740 const bool isAsyncExecution) {
741 return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript(
742 dstWorkerName,
743 qualifiedNameStr,
744 argsTuple,
745 kwargsDict,
746 rpcTimeoutSeconds,
747 isAsyncExecution));
748 },
749 py::call_guard<py::gil_scoped_release>());
750
751 module.def(
752 "_invoke_remote_builtin",
753 &pyRemoteBuiltin,
754 py::call_guard<py::gil_scoped_acquire>());
755
756 module.def(
757 "_invoke_remote_python_udf",
758 &pyRemotePythonUdf,
759 py::call_guard<py::gil_scoped_release>());
760
761 module.def(
762 "_invoke_remote_torchscript",
763 &pyRemoteTorchscript,
764 py::call_guard<py::gil_scoped_release>());
765
766 module.def(
767 "get_rpc_timeout",
768 []() {
769 return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() /
770 kSecToMsConversion;
771 },
772 R"(
773 Retrieve the default timeout for all RPCs that was set during RPC initialization.
774 The returned value will be in seconds.
775 Returns:
776 ``float`` indicating the RPC timeout in seconds.
777 )");
778
779 module.def(
780 "enable_gil_profiling",
781 [](bool flag) {
782 RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag);
783 },
784 R"(
785 Set whether GIL wait times should be enabled or not. This incurs a slight
786 overhead cost. Default is disabled for performance reasons.
787
788 Args:
789 flag (bool): True to set GIL profiling, False to disable.
790 )");
791
792 module.def(
793 "_set_rpc_timeout",
794 [](const float rpcTimeoutSeconds) {
795 auto rpcTimeout = std::chrono::milliseconds(
796 static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
797 RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout);
798 },
799 R"(
800 Set the default timeout for all RPCs. The input unit is expected to be
801 in seconds. If an RPC is not completed within this time, an exception
802 indicating it has timed out will be raised. To control timeout for
803 specific RPCs, a timeout parameter can be passed into
804 :meth:`~torch.distributed.rpc.rpc_sync` and
805 :meth:`~torch.distributed.rpc.rpc_async`.
806
807 Args:
808 rpcTimeoutSeconds (float): Timeout value in seconds.
809 )");
810
811 module.def(
812 "_enable_server_process_global_profiler",
813 &profiler::processglobal::enableServer);
814 module.def(
815 "_disable_server_process_global_profiler",
816 &profiler::processglobal::disableServer);
817
818 module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId);
819
820 py::class_<
821 RemoteProfilerManager,
822 std::unique_ptr<RemoteProfilerManager, py::nodelete>>(
823 module, "RemoteProfilerManager")
824 .def("set_current_profiling_key", [](const std::string& key) {
825 auto& inst = RemoteProfilerManager::getInstance();
826 inst.setCurrentKey(key);
827 });
828
829 module.def(
830 "_enable_jit_rref_pickle",
831 &enableJitRRefPickle,
832 R"(
833 Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with
834 pickled RRefs out of RPC contexts.
835
836
837 .. warning::
838 This is dangerous. If the module contains RRefs, the pickled
839 result must be sent over RPC and get unpickled on the receiving side
840 to restore the module. Otherwise, there will be RRef leaks, which
841 can potentially lead to program hang. When using this API, it is
842 applications responsibility to make sure that the above assumption
843 always holds.
844 )");
845 module.def("_disable_jit_rref_pickle", &disableJitRRefPickle);
846
847 Py_RETURN_TRUE;
848}
849
850} // namespace
851
// CPython method table for this extension module. rpc_init is exposed to
// Python as "_rpc_init" and takes no arguments (METH_NOARGS); the all-null
// entry is the required sentinel terminating the table.
static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
855
// Returns the method table above so torch's Python-binding setup can register
// these functions (i.e. _rpc_init) on the torch._C module. The returned
// pointer refers to file-scope static storage and stays valid for the process
// lifetime.
PyMethodDef* python_functions() {
  return methods;
}
859
860} // namespace rpc
861} // namespace distributed
862} // namespace torch
863