1 | #include <torch/csrc/python_headers.h> |
2 | |
3 | #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h> |
4 | #include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h> |
5 | #include <torch/csrc/distributed/rpc/py_rref.h> |
6 | #include <torch/csrc/distributed/rpc/python_functions.h> |
7 | #include <torch/csrc/distributed/rpc/python_rpc_handler.h> |
8 | #include <torch/csrc/distributed/rpc/request_callback_impl.h> |
9 | #include <torch/csrc/distributed/rpc/rpc_agent.h> |
10 | #include <torch/csrc/distributed/rpc/rref_context.h> |
11 | #include <torch/csrc/distributed/rpc/tensorpipe_agent.h> |
12 | #include <torch/csrc/distributed/rpc/torchscript_functions.h> |
13 | #include <torch/csrc/distributed/rpc/types.h> |
14 | #include <torch/csrc/jit/python/pybind_utils.h> |
15 | #include <torch/csrc/utils/object_ptr.h> |
16 | #include <torch/csrc/utils/pybind.h> |
17 | #include <torch/types.h> |
18 | |
19 | #include <pybind11/chrono.h> |
20 | #include <pybind11/operators.h> |
21 | |
22 | namespace torch { |
23 | namespace distributed { |
24 | namespace rpc { |
25 | |
26 | namespace { |
27 | |
// Default timeout (100,000 ms) used by
// `_delete_all_user_and_unforked_owner_rrefs` when the Python caller does not
// pass an explicit timeout.
constexpr std::chrono::milliseconds kDeleteAllUsersTimeout(100000);

// Shorthand for pybind11 classes whose C++ instances are held by
// std::shared_ptr (required for types shared between Python and the agent).
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
32 | |
// Registers all distributed-RPC bindings on the `torch._C._distributed_rpc`
// submodule. Invoked once from Python as `torch._C._rpc_init` (METH_NOARGS,
// hence the unused parameters). Returns Py_True on success and throws
// python_error if the required Python modules cannot be imported.
PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
  // Import the Python-side package first so its module-level initialization
  // has run before C++ bindings that it relies on are attached.
  auto rpc_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc"));
  if (!rpc_module) {
    throw python_error();
  }

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module) {
    throw python_error();
  }

  // Everything below is registered on torch._C._distributed_rpc.
  auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
  auto m =
      torch_C_m.def_submodule("_distributed_rpc", "distributed rpc bindings");

  auto module = py::handle(m).cast<py::module>();

  // Abstract options base class; concrete backends (e.g. the TensorPipe
  // options bound under USE_TENSORPIPE below) derive from it.
  auto rpcBackendOptions =
      shared_ptr_class_<RpcBackendOptions>(
          module,
          "RpcBackendOptions",
          R"(An abstract structure encapsulating the options passed into the RPC
          backend. An instance of this class can be passed in to
          :meth:`~torch.distributed.rpc.init_rpc` in order to initialize RPC
          with specific configurations, such as the RPC timeout and
          ``init_method`` to be used. )")
          .def(py::init<>())
          .def(
              py::init<float, std::string>(),
              py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
              py::arg("init_method") = kDefaultInitMethod)
          .def_readwrite(
              "rpc_timeout",
              &RpcBackendOptions::rpcTimeoutSeconds,
              R"(A float indicating the timeout to use for all
              RPCs. If an RPC does not complete in this timeframe, it will
              complete with an exception indicating that it has timed out.)")
          .def_readwrite(
              "init_method",
              &RpcBackendOptions::initMethod,
              R"(URL specifying how to initialize the process group.
              Default is ``env://``)");

  // The following C++ constants need to be cast so they can be used from
  // python.
  module.attr("_DEFAULT_RPC_TIMEOUT_SEC") = py::cast(kDefaultRpcTimeoutSeconds);
  module.attr("_UNSET_RPC_TIMEOUT") = py::cast(kUnsetRpcTimeout);
  module.attr("_DEFAULT_INIT_METHOD") = py::cast(kDefaultInitMethod);

  // WorkerInfo: lightweight (name, id) identity for a worker; hashable,
  // comparable, and picklable so it can be passed through RPC payloads.
  auto workerInfo =
      shared_ptr_class_<WorkerInfo>(
          module,
          "WorkerInfo",
          R"(A structure that encapsulates information of a worker in the system.
          Contains the name and ID of the worker. This class is not meant to
          be constructed directly, rather, an instance can be retrieved
          through :meth:`~torch.distributed.rpc.get_worker_info` and the
          result can be passed in to functions such as
          :meth:`~torch.distributed.rpc.rpc_sync`, :meth:`~torch.distributed.rpc.rpc_async`,
          :meth:`~torch.distributed.rpc.remote` to avoid copying a string on
          every invocation.)")
          .def(
              py::init<std::string, worker_id_t>(),
              py::arg("name"),
              py::arg("id"))
          .def_readonly(
              "name", &WorkerInfo::name_, R"(The name of the worker.)")
          .def_readonly(
              "id",
              &WorkerInfo::id_,
              R"(Globally unique id to identify the worker.)")
          .def("__eq__", &WorkerInfo::operator==, py::is_operator())
          // pybind11 suggests the syntax .def(hash(py::self)), with the
          // unqualified "hash" function call. However the
          // argument-dependent lookup for the function "hash" doesn't get
          // triggered in this context because it conflicts with the struct
          // c10::hash, so we need to use the qualified name
          // py::detail::hash, which unfortunately is in a detail namespace.
          .def(py::detail::hash(py::self)) // NOLINT
          .def(
              "__repr__",
              [](const WorkerInfo& workerInfo) {
                // Delegates to WorkerInfo's operator<< for the string form.
                std::ostringstream os;
                os << workerInfo;
                return os.str();
              })
          .def(py::pickle(
              /* __getstate__ */
              [](const WorkerInfo& workerInfo) {
                return py::make_tuple(workerInfo.name_, workerInfo.id_);
              },
              /* __setstate__ */
              [](py::tuple t) {
                TORCH_CHECK(t.size() == 2, "Invalid WorkerInfo state.");

                WorkerInfo info(
                    t[0].cast<std::string>(), t[1].cast<worker_id_t>());
                return info;
              }));

  // RpcAgent: abstract base of all transport agents. Methods release the GIL
  // since they may block on network / internal locks.
  auto rpcAgent =
      shared_ptr_class_<RpcAgent>(module, "RpcAgent")
          .def(
              "join",
              &RpcAgent::join,
              py::call_guard<py::gil_scoped_release>(),
              py::arg("shutdown") = false,
              py::arg("timeout") = 0)
          .def(
              "sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>())
          .def(
              "shutdown",
              &RpcAgent::shutdown,
              py::call_guard<py::gil_scoped_release>())
          // Two overloads disambiguated via explicit member-pointer casts:
          // no-arg (self info) and lookup-by-name.
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(void) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_info",
              (const WorkerInfo& (RpcAgent::*)(const std::string&) const) &
                  RpcAgent::getWorkerInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_worker_infos",
              &RpcAgent::getWorkerInfos,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_device_map",
              &RpcAgent::getDeviceMap,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_debug_info",
              &RpcAgent::getDebugInfo,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "get_metrics",
              &RpcAgent::getMetrics,
              py::call_guard<py::gil_scoped_release>());

  // PyRRef: Python-facing wrapper around a (possibly remote) RRef.
  auto pyRRef =
      shared_ptr_class_<PyRRef>(module, "PyRRef", R"(
          A class encapsulating a reference to a value of some type on a remote
          worker. This handle will keep the referenced remote value alive on the
          worker. A ``UserRRef`` will be deleted when 1) no references to it in
          both the application code and in the local RRef context, or 2) the
          application has called a graceful shutdown. Invoking methods on a
          deleted RRef leads to undefined behaviors. RRef implementation only
          offers best-effort error detection, and applications should not use
          ``UserRRefs`` after ``rpc.shutdown()``.

          .. warning::
              RRefs can only be serialized and deserialized by the RPC module.
              Serializing and deserializing RRefs without RPC (e.g., Python
              pickle, torch :meth:`~torch.save` / :meth:`~torch.load`,
              JIT :meth:`~torch.jit.save` / :meth:`~torch.jit.load`, etc.) will
              lead to errors.

          Args:
              value (object): The value to be wrapped by this RRef.
              type_hint (Type, optional): Python type that should be passed to
                  ``TorchScript`` compiler as type hint for ``value``.

          Example::
              Following examples skip RPC initialization and shutdown code
              for simplicity. Refer to RPC docs for those details.

              1. Create an RRef using rpc.remote

              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> rref = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
              >>> # get a copy of value from the RRef
              >>> x = rref.to_here()

              2. Create an RRef from a local object

              >>> import torch
              >>> from torch.distributed.rpc import RRef
              >>> x = torch.zeros(2, 2)
              >>> rref = RRef(x)

              3. Share an RRef with other workers

              >>> # On both worker0 and worker1:
              >>> def f(rref):
              >>>   return rref.to_here() + 1

              >>> # On worker0:
              >>> import torch
              >>> import torch.distributed.rpc as rpc
              >>> from torch.distributed.rpc import RRef
              >>> rref = RRef(torch.zeros(2, 2))
              >>> # the following RPC shares the rref with worker1, reference
              >>> # count is automatically updated.
              >>> rpc.rpc_sync("worker1", f, args=(rref,))
          )")
          .def(
              py::init<const py::object&, const py::object&>(),
              py::arg("value"),
              py::arg("type_hint") = py::none())
          .def(
              // not releasing GIL here to avoid context switch on getters
              "is_owner",
              &PyRRef::isOwner,
              R"(
                  Returns whether or not the current node is the owner of this
                  ``RRef``.
              )")
          .def(
              "confirmed_by_owner",
              &PyRRef::confirmedByOwner,
              R"(
                  Returns whether this ``RRef`` has been confirmed by the owner.
                  ``OwnerRRef`` always returns true, while ``UserRRef`` only
                  returns true when the owner knowns about this ``UserRRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner",
              &PyRRef::owner,
              R"(
                  Returns worker information of the node that owns this ``RRef``.
              )")
          .def(
              // not releasing GIL here to avoid context switch on getters
              "owner_name",
              &PyRRef::ownerName,
              R"(
                  Returns worker name of the node that owns this ``RRef``.
              )")
          .def(
              "to_here",
              &PyRRef::toHere,
              py::arg("timeout") = py::cast(kUnsetRpcTimeout),
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Blocking call that copies the value of the RRef from the owner
                  to the local node and returns it. If the current node is the
                  owner, returns a reference to the local value.

                  Args:
                      timeout (float, optional): Timeout for ``to_here``. If
                          the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this
                          argument is not provided, the default RPC timeout
                          (60s) will be used.
              )")
          .def(
              "local_value",
              &PyRRef::localValue,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  If the current node is the owner, returns a reference to the
                  local value. Otherwise, throws an exception.
              )")
          // The three proxy builders below (rpc_sync/rpc_async/remote) share
          // the same shape and differ only in the RRefProxyType they request.
          .def(
              "rpc_sync",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_SYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_sync`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_sync().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>   return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_sync(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_sync()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_sync().size()  # returns torch.Size([2, 2])
                      >>> rref.rpc_sync().view(1, 4)  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "rpc_async",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::RPC_ASYNC, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch an ``rpc_async`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.rpc_async().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>   return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.rpc_async(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.rpc_async()``.
                          If the call does not complete within this timeframe, an
                          exception indicating so will be raised. If this argument
                          is not provided, the default RPC timeout will be used.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.rpc_async().size().wait()  # returns torch.Size([2, 2])
                      >>> rref.rpc_async().view(1, 4).wait()  # returns tensor([[1., 1., 1., 1.]])
              )")
          .def(
              "remote",
              [](const PyRRef& self, float timeoutSeconds) {
                return self.createRRefProxy(
                    RRefProxyType::REMOTE, timeoutSeconds);
              },
              py::arg("timeout") = kUnsetRpcTimeout,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Create a helper proxy to easily launch a ``remote`` using
                  the owner of the RRef as the destination to run functions on
                  the object referenced by this RRef. More specifically,
                  ``rref.remote().func_name(*args, **kwargs)`` is the same as
                  the following:

                  >>> def run(rref, func_name, args, kwargs):
                  >>>   return getattr(rref.local_value(), func_name)(*args, **kwargs)
                  >>>
                  >>> rpc.remote(rref.owner(), run, args=(rref, func_name, args, kwargs))

                  Args:
                      timeout (float, optional): Timeout for ``rref.remote()``. If
                          the creation of this :class:`~torch.distributed.rpc.RRef`
                          is not successfully completed within the timeout, then the
                          next time there is an attempt to use the RRef
                          (such as ``to_here``), a timeout will be raised. If not
                          provided, the default RPC timeout will be used. Please see
                          ``rpc.remote()`` for specific timeout semantics for
                          :class:`~torch.distributed.rpc.RRef`.

                  Example::
                      >>> from torch.distributed import rpc
                      >>> rref = rpc.remote("worker1", torch.add, args=(torch.zeros(2, 2), 1))
                      >>> rref.remote().size().to_here()  # returns torch.Size([2, 2])
                      >>> rref.remote().view(1, 4).to_here()  # returns tensor([[1., 1., 1., 1.]])
              )")
          // Plain-pickle support is intentionally disabled: RRefs may only be
          // (de)serialized inside RPC (see _serialize/_deserialize below).
          .def(
              py::pickle(
                  /* __getstate__ */
                  [](const PyRRef& /* unused */) {
                    TORCH_CHECK(
                        false,
                        "Can not pickle rref in python pickler, rref can only be "
                        "pickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw, it's only here to satisfy Pybind API's
                    // requirement.
                    return py::make_tuple();
                  },
                  /* __setstate__ */
                  [](py::tuple /* unused */) { // NOLINT
                    TORCH_CHECK(
                        false,
                        "Can not unpickle rref in python pickler, rref can only be "
                        "unpickled when using RPC");
                    // Note that this return has no meaning since we always
                    // throw, it's only here to satisfy PyBind's API
                    // requirement.
                    return PyRRef(
                        py::cast<py::none>(Py_None),
                        py::cast<py::none>(Py_None));
                  }),
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_serialize",
              &PyRRef::pickle,
              py::call_guard<py::gil_scoped_release>())
          .def_static(
              "_deserialize",
              &PyRRef::unpickle,
              py::call_guard<py::gil_scoped_release>())
          .def(
              "_get_type",
              // Intentionally not releasing GIL, as most accesses just
              // retrieve cached type py::object
              &PyRRef::getRRefType,
              py::arg("timeout") = kUnsetRpcTimeout,
              py::arg("blocking") = true,
              R"(
                  If ``blocking=True``, returns the type of the data object
                  referenced by this ``RRef``. On the owner, this is same as
                  ``type(rref.local_value())``. Otherwise, returns a future to
                  this result. On a user, this will trigger an RPC to fetch the
                  ``type`` object from the owner. After this function is run
                  once, the ``type`` object is cached by the ``RRef``, and
                  subsequent invocations no longer trigger RPC. Note that this is
                  true regardless of the ``blocking`` argument of subsequent
                  calls.

                  Args:
                      rref (torch.distributed.rpc.RRef): The RRef to get type of.
                      timeout (float, optional): Timeout, in seconds for
                          ``_get_type``. If the call does not complete within
                          this timeframe, an exception indicating so will be
                          raised. If this argument is not provided, the default
                          RPC timeout will be used.
                      blocking (bool, optional): Whether to synchronously wait on
                          the RPC triggered by the first call and return the
                          type. If ``False``, will return a future. Default is
                          ``True``.
              )")
          .def(
              "_get_future",
              [](const PyRRef& self) {
                // Wrap the ivalue future so Python can .wait() on it.
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getFuture());
              },
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Returns the future that corresponds to the creation of this RRef
                  on the remote node. This is for internal use cases such as profiling
                  only.
              )")
          .def(
              "_get_profiling_future",
              [](const PyRRef& self) {
                return std::make_shared<jit::PythonFutureWrapper>(
                    self.getProfilingFuture());
              },
              // NOTE(review): uses gil_scoped_acquire (i.e. keeps the GIL
              // held) unlike most defs in this file — presumably because the
              // wrapped future holds py::objects; confirm before changing.
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Returns future that completes when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "_set_profiling_future",
              [](PyRRef& self,
                 const std::shared_ptr<jit::PythonFutureWrapper>&
                     wrappedFuture) {
                self.setProfilingFuture(wrappedFuture->fut);
              },
              py::call_guard<py::gil_scoped_acquire>(),
              R"(
                  Set future that is completed when the profiling event corresponding
                  to the creation of this RRef on the remote node has been recorded.
              )")
          .def(
              "backward",
              [](PyRRef& self,
                 int64_t dist_autograd_ctx_id,
                 bool retain_graph) {
                self.backward(dist_autograd_ctx_id, retain_graph);
              },
              // -1 is the sentinel for "no distributed autograd context":
              // perform a purely local backward pass.
              py::arg("dist_autograd_ctx_id") = -1,
              py::arg("retain_graph") = false,
              py::call_guard<py::gil_scoped_release>(),
              R"(
                  Runs the backward pass using the RRef as the root of the
                  backward pass. If ``dist_autograd_ctx_id`` is provided,
                  we perform a distributed backward pass using the provided
                  ctx_id starting from the owner of the RRef. In this case,
                  :meth:`~torch.distributed.autograd.get_gradients` should be
                  used to retrieve the gradients. If ``dist_autograd_ctx_id``
                  is ``None``, it is assumed that this is a local autograd graph
                  and we only perform a local backward pass. In the local case,
                  the node calling this API has to be the owner of the RRef.
                  The value of the RRef is expected to be a scalar Tensor.

                  Args:
                      dist_autograd_ctx_id (int, optional): The distributed
                          autograd context id for which we should retrieve the
                          gradients (default: -1).
                      retain_graph(bool, optional): If ``False``, the graph used to
                          compute the grad will be freed. Note that in nearly all
                          cases setting this option to ``True`` is not needed and
                          often can be worked around in a much more efficient way.
                          Usually, you need to set this to ``True`` to run backward
                          multiple times (default: False).

                  Example::
                      >>> import torch.distributed.autograd as dist_autograd
                      >>> with dist_autograd.context() as context_id:
                      >>>     rref.backward(context_id)
              )")
          // not releasing GIL to avoid context switch
          .def("__repr__", &PyRRef::str);

#ifdef USE_TENSORPIPE

  // Base class: torch.distributed.rpc.RpcBackendOptions.
  py::class_<TensorPipeRpcBackendOptions>(
      module, "_TensorPipeRpcBackendOptionsBase", rpcBackendOptions)
      .def(
          py::init<
              int,
              optional<std::vector<std::string>>,
              optional<std::vector<std::string>>,
              float,
              std::string,
              std::unordered_map<std::string, DeviceMap>,
              std::vector<c10::Device>>(),
          py::arg("num_worker_threads") = kDefaultNumWorkerThreads,
          py::arg("_transports") = optional<std::vector<std::string>>(),
          py::arg("_channels") = optional<std::vector<std::string>>(),
          py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
          py::arg("init_method") = kDefaultInitMethod,
          py::arg("device_maps") = std::unordered_map<std::string, DeviceMap>(),
          py::arg("devices") = std::vector<c10::Device>())
      .def_readwrite(
          "num_worker_threads",
          &TensorPipeRpcBackendOptions::numWorkerThreads,
          R"(
              The number of threads in the thread-pool used by
              :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
              requests.
          )")
      .def_readwrite(
          "device_maps",
          &TensorPipeRpcBackendOptions::deviceMaps,
          R"(The device map locations.)")
      .def_readwrite(
          "devices",
          &TensorPipeRpcBackendOptions::devices,
          R"(All devices used by the local agent.)")
      .def("_set_device_map", &TensorPipeRpcBackendOptions::setDeviceMap);

  module.attr("_DEFAULT_NUM_WORKER_THREADS") =
      py::cast(kDefaultNumWorkerThreads);

  // TensorPipeAgent: the default RPC transport; constructed via a factory
  // lambda so the agent's destructor can run without holding the GIL.
  shared_ptr_class_<TensorPipeAgent>(module, "TensorPipeAgent", rpcAgent)
      .def(
          py::init(
              [](const c10::intrusive_ptr<::c10d::Store>& store,
                 std::string selfName,
                 worker_id_t selfId,
                 optional<int> worldSize,
                 TensorPipeRpcBackendOptions opts,
                 std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
                 std::vector<c10::Device> devices) {
                // Custom deleter destroys the agent with the GIL released to
                // avoid deadlocks during interpreter shutdown.
                return std::shared_ptr<TensorPipeAgent>(
                    new TensorPipeAgent(
                        store,
                        std::move(selfName),
                        selfId,
                        worldSize,
                        std::move(opts),
                        std::move(reverseDeviceMaps),
                        std::move(devices),
                        std::make_unique<RequestCallbackImpl>()),
                    impl::destroy_without_gil<TensorPipeAgent>);
              }),
          py::arg("store"),
          py::arg("name"),
          py::arg("rank"),
          py::arg("world_size"),
          py::arg("rpc_backend_options"),
          py::arg("reverse_device_maps"),
          py::arg("devices"))
      .def(
          "join",
          &TensorPipeAgent::join,
          py::call_guard<py::gil_scoped_release>(),
          py::arg("shutdown") = false,
          py::arg("timeout") = 0)
      .def(
          "shutdown",
          &TensorPipeAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      // Overloads re-registered on the derived class so pybind dispatches
      // them for TensorPipeAgent instances (by self / by name / by id).
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(void) const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(const std::string&) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (TensorPipeAgent::*)(worker_id_t id) const) &
              TensorPipeAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          (std::vector<WorkerInfo>(TensorPipeAgent::*)() const) &
              TensorPipeAgent::getWorkerInfos,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_device_map",
          (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) &
              TensorPipeAgent::getDeviceMap,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_get_backend_options",
          &TensorPipeAgent::getBackendOptions,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "_update_group_membership",
          &TensorPipeAgent::updateGroupMembership,
          py::call_guard<py::gil_scoped_release>())
      .def_readonly("is_static_group", &TensorPipeAgent::isStaticGroup_)
      .def_property_readonly("store", &TensorPipeAgent::getStore);

#endif // USE_TENSORPIPE

  // --- Module-level agent lifecycle helpers ---

  module.def("_is_current_rpc_agent_set", &RpcAgent::isCurrentRpcAgentSet);

  module.def("_get_current_rpc_agent", &RpcAgent::getCurrentRpcAgent);

  module.def(
      "_set_and_start_rpc_agent",
      [](const std::shared_ptr<RpcAgent>& rpcAgent) {
        RpcAgent::setCurrentRpcAgent(rpcAgent);
        // Initializing typeResolver inside RpcAgent constructor will make
        // RpcAgent have python dependency. To avoid RpcAgent to have python
        // dependency, setTypeResolver() here.
        std::shared_ptr<TypeResolver> typeResolver =
            std::make_shared<TypeResolver>([&](const c10::QualifiedName& qn) {
              auto typePtr = PythonRpcHandler::getInstance().parseTypeFromStr(
                  qn.qualifiedName());
              return c10::StrongTypePtr(
                  PythonRpcHandler::getInstance().jitCompilationUnit(),
                  std::move(typePtr));
            });
        rpcAgent->setTypeResolver(typeResolver);
        rpcAgent->start();
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_reset_current_rpc_agent",
      []() { RpcAgent::setCurrentRpcAgent(nullptr); },
      py::call_guard<py::gil_scoped_release>());

  // --- RRef context management ---

  module.def(
      "_delete_all_user_and_unforked_owner_rrefs",
      [](std::chrono::milliseconds timeoutMillis) {
        RRefContext::getInstance().delAllUsersAndUnforkedOwners(timeoutMillis);
      },
      py::arg("timeout") = kDeleteAllUsersTimeout,
      py::call_guard<py::gil_scoped_release>());

  module.def("_destroy_rref_context", [](bool ignoreRRefLeak) {
    // NB: do not release GIL in the function. The destroyInstance() method
    // returns a list of deleted OwnerRRefs that hold py::object instances.
    // Clearing those OwnerRRefs are likely to trigger Python deref, which
    // requires GIL.
    RRefContext::getInstance().destroyInstance(ignoreRRefLeak).clear();
  });

  module.def("_rref_context_get_debug_info", []() {
    return RRefContext::getInstance().getDebugInfo();
  });

  module.def(
      "_cleanup_python_rpc_handler",
      []() { PythonRpcHandler::getInstance().cleanup(); },
      py::call_guard<py::gil_scoped_release>());

  // --- RPC invocation entry points (used by the Python frontend) ---

  module.def(
      "_invoke_rpc_builtin",
      [](const WorkerInfo& dst,
         const std::string& opName,
         const float rpcTimeoutSeconds,
         const py::args& args,
         const py::kwargs& kwargs) {
        return std::make_shared<jit::PythonFutureWrapper>(
            pyRpcBuiltin(dst, opName, args, kwargs, rpcTimeoutSeconds));
      },
      // NOTE(review): GIL is kept held here (py::args/py::kwargs are Python
      // objects processed inside the call) — confirm before changing.
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_rpc_python_udf",
      [](const WorkerInfo& dst,
         std::string& pickledPythonUDF,
         std::vector<torch::Tensor>& tensors,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcPythonUdf(
            dst,
            pickledPythonUDF,
            tensors,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_rpc_torchscript",
      [](const std::string& dstWorkerName,
         const std::string& qualifiedNameStr,
         const py::tuple& argsTuple,
         const py::dict& kwargsDict,
         const float rpcTimeoutSeconds,
         const bool isAsyncExecution) {
        return std::make_shared<jit::PythonFutureWrapper>(pyRpcTorchscript(
            dstWorkerName,
            qualifiedNameStr,
            argsTuple,
            kwargsDict,
            rpcTimeoutSeconds,
            isAsyncExecution));
      },
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_builtin",
      &pyRemoteBuiltin,
      py::call_guard<py::gil_scoped_acquire>());

  module.def(
      "_invoke_remote_python_udf",
      &pyRemotePythonUdf,
      py::call_guard<py::gil_scoped_release>());

  module.def(
      "_invoke_remote_torchscript",
      &pyRemoteTorchscript,
      py::call_guard<py::gil_scoped_release>());

  // --- Timeout and profiling configuration ---

  module.def(
      "get_rpc_timeout",
      []() {
        // Agent stores the timeout in milliseconds; Python API uses seconds.
        return RpcAgent::getCurrentRpcAgent()->getRpcTimeout().count() /
            kSecToMsConversion;
      },
      R"(
          Retrieve the default timeout for all RPCs that was set during RPC initialization.
          The returned value will be in seconds.
          Returns:
            ``float`` indicating the RPC timeout in seconds.
      )");

  module.def(
      "enable_gil_profiling",
      [](bool flag) {
        RpcAgent::getCurrentRpcAgent()->enableGILProfiling(flag);
      },
      R"(
    Set whether GIL wait times should be enabled or not. This incurs a slight
    overhead cost. Default is disabled for performance reasons.

    Args:
        flag (bool): True to set GIL profiling, False to disable.
      )");

  module.def(
      "_set_rpc_timeout",
      [](const float rpcTimeoutSeconds) {
        // Convert the seconds-based Python value to the agent's millisecond
        // representation.
        auto rpcTimeout = std::chrono::milliseconds(
            static_cast<int>(rpcTimeoutSeconds * kSecToMsConversion));
        RpcAgent::getCurrentRpcAgent()->setRpcTimeout(rpcTimeout);
      },
      R"(
          Set the default timeout for all RPCs. The input unit is expected to be
          in seconds. If an RPC is not completed within this time, an exception
          indicating it has timed out will be raised. To control timeout for
          specific RPCs, a timeout parameter can be passed into
          :meth:`~torch.distributed.rpc.rpc_sync` and
          :meth:`~torch.distributed.rpc.rpc_async`.

          Args:
            rpcTimeoutSeconds (float): Timeout value in seconds.
      )");

  module.def(
      "_enable_server_process_global_profiler",
      &profiler::processglobal::enableServer);
  module.def(
      "_disable_server_process_global_profiler",
      &profiler::processglobal::disableServer);

  module.def("_set_profiler_node_id", &at::RecordFunction::setDefaultNodeId);

  // RemoteProfilerManager is a process-wide singleton; py::nodelete keeps
  // Python from ever destroying it.
  py::class_<
      RemoteProfilerManager,
      std::unique_ptr<RemoteProfilerManager, py::nodelete>>(
      module, "RemoteProfilerManager")
      .def("set_current_profiling_key", [](const std::string& key) {
        auto& inst = RemoteProfilerManager::getInstance();
        inst.setCurrentKey(key);
      });

  module.def(
      "_enable_jit_rref_pickle",
      &enableJitRRefPickle,
      R"(
        Allows ``torch.jit.save`` to save a ``torch.jit.ScriptModule`` with
        pickled RRefs out of RPC contexts.


        .. warning::
            This is dangerous. If the module contains RRefs, the pickled
            result must be sent over RPC and get unpickled on the receiving side
            to restore the module. Otherwise, there will be RRef leaks, which
            can potentially lead to program hang. When using this API, it is
            applications responsibility to make sure that the above assumption
            always holds.
      )");
  module.def("_disable_jit_rref_pickle", &disableJitRRefPickle);

  Py_RETURN_TRUE;
}
849 | |
850 | } // namespace |
851 | |
// Method table handed to CPython via python_functions(); the all-null entry
// is the mandatory array terminator required by the PyMethodDef protocol.
static PyMethodDef methods[] = { // NOLINT
    {"_rpc_init", rpc_init, METH_NOARGS, nullptr},
    {nullptr, nullptr, 0, nullptr}};
855 | |
856 | PyMethodDef* python_functions() { |
857 | return methods; |
858 | } |
859 | |
860 | } // namespace rpc |
861 | } // namespace distributed |
862 | } // namespace torch |
863 | |